Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f1e53807
Unverified
Commit
f1e53807
authored
Feb 10, 2025
by
Illia Silin
Committed by
GitHub
Feb 10, 2025
Browse files
Merge branch 'develop' into ck_host_lib
parents
7450417d
d9f1ead3
Changes
458
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3434 additions
and
546 deletions
+3434
-546
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+26
-9
include/ck/utility/amd_ck_fp8.hpp
include/ck/utility/amd_ck_fp8.hpp
+996
-0
include/ck/utility/amd_inline_asm.hpp
include/ck/utility/amd_inline_asm.hpp
+22
-1
include/ck/utility/amd_wave_read_first_lane.hpp
include/ck/utility/amd_wave_read_first_lane.hpp
+14
-13
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+265
-2
include/ck/utility/array.hpp
include/ck/utility/array.hpp
+4
-2
include/ck/utility/blkgemmpipe_scheduler.hpp
include/ck/utility/blkgemmpipe_scheduler.hpp
+10
-2
include/ck/utility/container_helper.hpp
include/ck/utility/container_helper.hpp
+3
-3
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+1935
-483
include/ck/utility/debug.hpp
include/ck/utility/debug.hpp
+2
-1
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+21
-8
include/ck/utility/e8m0.hpp
include/ck/utility/e8m0.hpp
+80
-0
include/ck/utility/enable_if.hpp
include/ck/utility/enable_if.hpp
+18
-1
include/ck/utility/env.hpp
include/ck/utility/env.hpp
+3
-1
include/ck/utility/functional.hpp
include/ck/utility/functional.hpp
+3
-3
include/ck/utility/functional4.hpp
include/ck/utility/functional4.hpp
+6
-6
include/ck/utility/integral_constant.hpp
include/ck/utility/integral_constant.hpp
+6
-1
include/ck/utility/is_detected.hpp
include/ck/utility/is_detected.hpp
+9
-7
include/ck/utility/loop_scheduler.hpp
include/ck/utility/loop_scheduler.hpp
+6
-2
include/ck/utility/magic_division.hpp
include/ck/utility/magic_division.hpp
+5
-1
No files found.
Too many changes to show.
To preserve performance only
458 of 458+
files are displayed.
Plain diff
Email patch
include/ck/utility/amd_buffer_addressing.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "data_type.hpp"
...
...
@@ -429,7 +429,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
uint8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
(
is_same
<
T
,
uint8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
pk_i4_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
using
r_t
=
typename
vector_type
<
T
,
N
>::
type
;
...
...
@@ -549,8 +550,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
f8_fnuz_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_fnuz_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
fp8_storage_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
...
...
@@ -578,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
tmp
.
template
AsType
<
half2_t
>()[
i
]);
});
}
#if defined(__gfx942__)
#if defined(__gfx942__)
|| defined(__gfx950__)
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
{
vector_type
<
bhalf_t
,
N
>
tmp
{
src_thread_data
};
...
...
@@ -843,8 +846,8 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#else
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
vector_t
tmp
{
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
)
}
;
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
#endif
}
...
...
@@ -873,8 +876,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
vector_t
tmp
{
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
)
}
;
return
src_thread_element_valid
?
tmp
:
vector_t
(
customized_value
);
}
...
...
@@ -1018,15 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
constexpr
auto
bytes_per_thread
=
sizeof
(
T
)
*
NumElemsPerThread
;
static_assert
(
bytes_per_thread
==
dword_bytes
);
#ifndef CK_CODE_GEN_RTC
const
uint32_t
*
global_ptr
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uintptr_t
>
(
global_base_ptr
));
#else
const
uint32_t
*
global_ptr
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
size_t
>
(
global_base_ptr
));
#endif
const
int32x4_t
src_resource
=
make_wave_buffer_resource
(
global_ptr
,
src_element_space_size
);
const
index_t
global_offset_bytes
=
is_valid
?
global_offset
*
sizeof
(
T
)
:
0x80000000
;
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T
*
lds_ptr
=
lds_base_ptr
+
lds_offset
;
#ifndef CK_CODE_GEN_RTC
auto
const
lds_ptr_sgpr
=
__builtin_amdgcn_readfirstlane
((
reinterpret_cast
<
uintptr_t
>
(
lds_ptr
)));
#else
auto
const
lds_ptr_sgpr
=
__builtin_amdgcn_readfirstlane
((
reinterpret_cast
<
size_t
>
(
lds_ptr
)));
#endif
asm
volatile
(
"s_mov_b32 m0, %0;
\n\t
"
"buffer_load_dword %1, %2, 0 offen lds;
\n\t
"
::
"s"
(
lds_ptr_sgpr
),
"v"
(
global_offset_bytes
),
...
...
@@ -1035,8 +1047,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
=
#ifndef CK_CODE_GEN_RTC
reinterpret_cast
<
__attribute__
((
address_space
(
3
)))
uint32_t
*>
(
reinterpret_cast
<
uintptr_t
>
(
lds_base_ptr
+
lds_offset
));
#else
reinterpret_cast
<
__attribute__
((
address_space
(
3
)))
uint32_t
*>
(
reinterpret_cast
<
size_t
>
(
lds_base_ptr
+
lds_offset
));
#endif
llvm_amdgcn_raw_buffer_load_lds
(
src_resource
,
lds_ptr
,
sizeof
(
uint32_t
),
global_offset_bytes
,
0
,
0
,
0
);
...
...
include/ck/utility/amd_ck_fp8.hpp
0 → 100644
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/type.hpp"
// Normalize the FP8 flavour switches: after this block CK_USE_FNUZ_FP8 and
// CK_USE_OCP_FP8 are always defined and expand to 0 or 1, so the rest of the
// header can test them with plain `#if`. A macro that was defined before this
// point (with any replacement text) is redefined to 1.
#ifdef CK_USE_FNUZ_FP8
#define CK_USE_FNUZ_FP8 1
#else
#define CK_USE_FNUZ_FP8 0
#endif
#ifdef CK_USE_OCP_FP8
#define CK_USE_OCP_FP8 1
#else
#define CK_USE_OCP_FP8 0
#endif
// CK_FP8_CVT_FAST_PATH is 1 only while compiling device code for one of the
// listed gfx targets; it guards the helpers in this header that are built on
// the __builtin_amdgcn_cvt_* FP8 conversion builtins.
#if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
defined(__gfx1201__) || defined(__gfx950__)) && \
__HIP_DEVICE_COMPILE__
#define CK_FP8_CVT_FAST_PATH 1
#else
#define CK_FP8_CVT_FAST_PATH 0
#endif
// CK_OCP_FP8_CVT_FAST_PATH is the same idea restricted to the targets on which
// the OCP wrapper types below take the builtin path for their conversions.
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
#define CK_OCP_FP8_CVT_FAST_PATH 1
#else
#define CK_OCP_FP8_CVT_FAST_PATH 0
#endif
namespace
ck
{
// FNUZ 8-bit float types are represented by distinct 8-bit integer types:
// a signed _BitInt(8) for f8_fnuz_t ...
using
f8_fnuz_t
=
_BitInt
(
8
);
// ... and an unsigned _BitInt(8) for bf8_fnuz_t, so the two can be told apart
// by overload resolution and template specialization.
using
bf8_fnuz_t
=
unsigned
_BitInt
(
8
);
// Raw 8-bit container holding an encoded FP8 value (no interpretation attached).
using fp8_storage_t = unsigned char;
/**
 * \brief Selects how the 8 bits of an fp8_storage_t are interpreted.
 */
enum class ck_fp8_interpretation_t
{
    CK_E4M3_OCP  = 0, ///< OCP E4M3
    CK_E5M2_OCP  = 1, ///< OCP E5M2
    CK_E4M3_FNUZ = 2, ///< FP8
    CK_E5M2_FNUZ = 3, ///< BF8
};
/**
 * \brief Selects what a conversion to FP8 does with out-of-range values.
 */
enum class ck_saturation_t
{
    CK_NOSAT     = 0, ///< No saturation - replace with NaN or Inf
    CK_SATFINITE = 1, ///< Saturate to finite
};
namespace
fp8_impl
{
// Two-lane vector types used by the packed conversion helpers in this namespace.
using fp8x2_storage_t = fp8_storage_t __attribute__((ext_vector_type(2)));
using float2_t        = float __attribute__((ext_vector_type(2)));
// FNUZ E4M3: the only NaN encoding is the bit pattern 0x80.
__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
{
    constexpr unsigned char nan_pattern = 0x80;
    return static_cast<unsigned char>(a) == nan_pattern;
}
// FNUZ E5M2: the only NaN encoding is the bit pattern 0x80.
__host__ __device__ static inline constexpr bool fnuz_bf8_is_nan(bf8_fnuz_t a)
{
    constexpr unsigned char nan_pattern = 0x80;
    return static_cast<unsigned char>(a) == nan_pattern;
}
// OCP E4M3: NaN when all exponent and mantissa bits are set (sign bit ignored).
__host__ __device__ static inline constexpr bool ocp_f8_is_nan(fp8_storage_t a)
{
    const int magnitude_bits = a & 0x7f;
    return magnitude_bits == 0x7f;
}
// OCP E5M2: NaN when the sign-stripped bits are above 0x7c, i.e. the exponent
// field is all ones and the mantissa is non-zero.
__host__ __device__ static inline constexpr bool ocp_bf8_is_nan(fp8_storage_t a)
{
    const int magnitude_bits = a & 0x7f;
    return magnitude_bits > 0x7c;
}
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
// This has been modified to handle double types as well
/**
 * \brief Software decode of an 8-bit float encoding into half, float or double.
 *
 * \tparam T       destination type: _Float16, float or double
 * \tparam wm      mantissa width of the 8-bit format
 * \tparam we      exponent width of the 8-bit format
 * \tparam is_fnuz true: 0x80 decodes to NaN and no Inf encodings are recognized;
 *                 false: OCP rules (0x80 is -0, E4M3 has NaN only, E5M2 has Inf and NaN)
 * \tparam clip    only consulted for OCP E5M2 infinities: return +/-57344 instead of +/-Inf
 * \param x        raw 8-bit encoding
 * \return         the decoded value as T
 */
template <typename T, int wm, int we, bool is_fnuz, bool clip = false>
__host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
{
    constexpr bool is_half   = __hip_internal::is_same<T, _Float16>::value;
    constexpr bool is_float  = __hip_internal::is_same<T, float>::value;
    constexpr bool is_double = __hip_internal::is_same<T, double>::value;
    static_assert(is_half || is_float || is_double, "only half, float and double are supported");

    // Exponent / mantissa widths of the destination format.
    constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
    constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);

    // Special values of the destination type, built from their bit patterns.
    T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
    if constexpr(is_half)
    {
        const unsigned short int ihInf    = 0x7C00;
        const unsigned short int ihNegInf = 0xFC00;
        const unsigned short int ihNaN    = 0x7C01;
        const unsigned short int ihNeg0   = 0x8000;
        /* Max number in e5m2 57344*/
        const unsigned short int ifmax = 0x7B00;
        const unsigned short int ifmin = 0xFB00;

        fInf    = bit_cast<_Float16>(ihInf);
        fNegInf = bit_cast<_Float16>(ihNegInf);
        fNaN    = bit_cast<_Float16>(ihNaN);
        fNeg0   = bit_cast<_Float16>(ihNeg0);
        fmax    = bit_cast<_Float16>(ifmax);
        fmin    = bit_cast<_Float16>(ifmin);
    }
    else if constexpr(is_float)
    {
        const unsigned int ifInf    = 0x7F800000;
        const unsigned int ifNegInf = 0xFF800000;
        const unsigned int ifNaN    = 0x7F800001;
        const unsigned int ifNeg0   = 0x80000000;
        /* Max number in e5m2 57344*/
        const unsigned int ifmax = 0x47600000;
        const unsigned int ifmin = 0xC7600000;

        fInf    = bit_cast<float>(ifInf);
        fNegInf = bit_cast<float>(ifNegInf);
        fNaN    = bit_cast<float>(ifNaN);
        fNeg0   = bit_cast<float>(ifNeg0);
        fmax    = bit_cast<float>(ifmax);
        fmin    = bit_cast<float>(ifmin);
    }
    else if constexpr(is_double)
    {
        const unsigned long long ifInf    = 0x7FF0000000000000ull;
        const unsigned long long ifNegInf = 0xFFF0000000000000ull;
        const unsigned long long ifNaN    = 0x7FF0000000000001ull;
        const unsigned long long ifNeg0   = 0x8000000000000000ull;
        /* Max number in e5m2 57344*/
        const unsigned long long ifmax = 0x40EC000000000000ull;
        const unsigned long long ifmin = 0xC0EC000000000000ull;

        fInf    = bit_cast<double>(ifInf);
        fNegInf = bit_cast<double>(ifNegInf);
        fNaN    = bit_cast<double>(ifNaN);
        fNeg0   = bit_cast<double>(ifNeg0);
        fmax    = bit_cast<double>(ifmax);
        fmin    = bit_cast<double>(ifmin);
    }

    if(x == 0)
    {
        return 0;
    }

    // Split the encoding into sign / exponent / mantissa fields.
    unsigned long long sign     = x >> 7;
    unsigned long long mantissa = x & ((1 << wm) - 1);
    int exponent                = (x & 0x7F) >> wm;

    if constexpr(is_fnuz)
    {
        // FNUZ: "negative zero" is the single NaN encoding.
        if(x == 0x80)
        {
            return fNaN;
        }
    }
    else
    {
        if(x == 0x80)
        {
            return fNeg0;
        }
        if constexpr(we == 4)
        { // e4m3
            if((x & 0x7F) == 0x7F)
            {
                return fNaN;
            }
        }
        else if((x & 0x7C) == 0x7C)
        { // e5m2
            if((x & 0x3) == 0)
            {
                if constexpr(clip)
                {
                    return sign ? fmin : fmax;
                }
                return sign ? fNegInf : fInf;
            }
            return fNaN;
        }
    }

    // Unsigned integer type with the same size as T, used to assemble the result bits.
    typename std::conditional<
        sizeof(T) == 2,
        unsigned short int,
        typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type
        retval;

    // OCP E5M2 -> half: the 8 bits are exactly the upper byte of the half encoding.
    if constexpr(we == 5 && is_half && !is_fnuz)
    {
        retval = x << 8;
        return bit_cast<T>(retval);
    }

    const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);

    // subnormal input
    if(exponent == 0)
    {
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        int sh = 1 + __clz(mantissa) - (32 - wm);
#else
        int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
#endif
        // Normalize: shift the leading one out of the mantissa field.
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1ull << wm) - 1);
    }
    exponent += exp_low_cutoff - 1;
    mantissa <<= wmo - wm;

    // subnormal output (occurs when T=half, we=5, is_fnuz=true)
    if(exponent <= 0)
    {
        // 64-bit one: a plain `1 << wmo` is an int shift, which is undefined for
        // wmo == 52 (T = double).
        mantissa |= 1ull << wmo;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }

    if constexpr(sizeof(T) == 2)
        retval = (sign << 15) | (exponent << 10) | mantissa;
    else if constexpr(sizeof(T) == 4)
        retval = (sign << 31) | (exponent << 23) | mantissa;
    else
        retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;

    return bit_cast<T>(retval);
}
#if CK_FP8_CVT_FAST_PATH
/**
 * \brief Convert one encoded 8-bit float to float using the AMDGPU conversion builtins.
 *
 * E4M3 interpretations (FNUZ or OCP) use __builtin_amdgcn_cvt_f32_fp8, the E5M2
 * interpretations use __builtin_amdgcn_cvt_f32_bf8; byte 0 of the source word is selected.
 *
 * \tparam interpret interpretation of the encoded value
 * \param v          encoded value
 * \return           converted float
 */
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8(fp8_storage_t v)
{
    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only FNUZ and OCP interpretations are supported");

    // Zero-extend the encoded byte into the low byte of a 32-bit word. This
    // avoids reading a union whose upper three bytes were never written.
    const unsigned int src_word = v;

    if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                 (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
    {
        return __builtin_amdgcn_cvt_f32_fp8(src_word, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_f32_bf8(src_word, 0);
    }
}
/**
 * \brief Convert a pair of encoded 8-bit floats to a float2_t with the packed
 *        AMDGPU conversion builtins.
 *
 * \tparam interpret interpretation shared by both encoded values
 * \param v          two encoded values
 * \return           the two converted floats
 */
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
{
    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only FNUZ and OCP interpretations are supported");

    constexpr bool uses_fp8_builtin = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                                      (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP);

    // Both bytes reinterpreted as one 16-bit word for the packed builtin.
    const auto packed_pair = bit_cast<uint16_t>(v);

    if constexpr(uses_fp8_builtin)
    {
        return __builtin_amdgcn_cvt_pk_f32_fp8(packed_pair, false);
    }
    else
    {
        return __builtin_amdgcn_cvt_pk_f32_bf8(packed_pair, false);
    }
}
#endif
}
// namespace fp8_impl
/// OCP E4M3 value: an fp8_storage_t plus the conversions and defaults that go with it.
struct f8_ocp_t
{
    using data_type = fp8_storage_t;
    data_type data; // raw encoded bits

    static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
    static constexpr ck_fp8_interpretation_t default_interpret =
        ck_fp8_interpretation_t::CK_E4M3_OCP;

    static constexpr unsigned int we = 4; // exponent width
    static constexpr unsigned int wm = 3; // mantissa width

    /// Bitwise equality, except that a NaN never compares equal (NaN != NaN).
    __host__ __device__ constexpr bool operator==(const f8_ocp_t& other) const
    {
        if(fp8_impl::ocp_f8_is_nan(data))
        {
            return false;
        }
        return data == other.data;
    }

    /// Decode to float; host-only unless CK_USE_OCP_FP8 is enabled.
#if CK_USE_OCP_FP8
    __host__ __device__
#else
    __host__
#endif
        explicit
        operator float() const
    {
#if CK_OCP_FP8_CVT_FAST_PATH
        return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
        // XXX: clip==false must be consistent with operator _Float16
        return fp8_impl::cast_from_f8<float, wm, we, false>(this->data);
#endif
    }

    /// Decode to _Float16; host-only unless CK_USE_OCP_FP8 is enabled.
#if CK_USE_OCP_FP8
    __host__ __device__
#else
    __host__
#endif
        explicit
        operator _Float16() const
    {
#if CK_OCP_FP8_CVT_FAST_PATH
        const float as_f32 = fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
        return static_cast<_Float16>(as_f32);
#else
        // XXX: clip==false must be consistent with operator float
        return fp8_impl::cast_from_f8<_Float16, wm, we, false>(this->data);
#endif
    }
};
/// OCP E5M2 value: an fp8_storage_t plus the conversions and defaults that go with it.
struct bf8_ocp_t
{
    using data_type = fp8_storage_t;
    data_type data; // raw encoded bits

    static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
    static constexpr ck_fp8_interpretation_t default_interpret =
        ck_fp8_interpretation_t::CK_E5M2_OCP;

    static constexpr unsigned int we = 5; // exponent width
    static constexpr unsigned int wm = 2; // mantissa width

    /// Bitwise equality, except that a NaN never compares equal (NaN != NaN).
    __host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
    {
        return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN
    }

    /// Decode to float; host-only unless CK_USE_OCP_FP8 is enabled.
#if CK_USE_OCP_FP8
    __host__ __device__ explicit operator float() const
#else
    __host__ explicit operator float() const
#endif
    {
        // Guard with CK_OCP_FP8_CVT_FAST_PATH (as f8_ocp_t does) rather than raw
        // __gfx*__ checks: cast_to_f32_from_f8 is only declared when
        // CK_FP8_CVT_FAST_PATH is set, which additionally requires device compilation.
#if CK_OCP_FP8_CVT_FAST_PATH
        return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
        return fp8_impl::cast_from_f8<float, wm, we, false>(
            this->data); // XXX: clip==false must be consistent with operator _Float16
#endif
    }

    /// Decode to _Float16; host-only unless CK_USE_OCP_FP8 is enabled.
#if CK_USE_OCP_FP8
    __host__ __device__ explicit operator _Float16() const
#else
    __host__ explicit operator _Float16() const
#endif
    {
#if CK_OCP_FP8_CVT_FAST_PATH
        return static_cast<_Float16>(
            fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
        return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
            this->data); // XXX: clip==false must be consistent with operator float
#endif
    }
};
// Primary template of the NaN predicate for the 8-bit float types. It is only
// declared here; each supported type provides an explicit specialization below.
template
<
typename
T
>
__host__
__device__
static
inline
constexpr
bool
fp8_is_nan
(
T
);
/// NaN test for OCP E4M3 values.
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(f8_ocp_t a)
{
    const fp8_storage_t raw_bits = a.data;
    return fp8_impl::ocp_f8_is_nan(raw_bits);
}
/// NaN test for OCP E5M2 values.
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_ocp_t a)
{
    const fp8_storage_t raw_bits = a.data;
    return fp8_impl::ocp_bf8_is_nan(raw_bits);
}
/// NaN test for FNUZ E4M3 values.
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(f8_fnuz_t a)
{
    const bool is_nan = fp8_impl::fnuz_f8_is_nan(a);
    return is_nan;
}
/// NaN test for FNUZ E5M2 values.
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a)
{
    const bool is_nan = fp8_impl::fnuz_bf8_is_nan(a);
    return is_nan;
}
/// Infinity test for the 8-bit float types. The generic version reports "never
/// infinite"; bf8_ocp_t overrides it with an explicit specialization.
template <typename T,
          ck::enable_if_t<is_same_v<T, bf8_ocp_t> || is_same_v<T, f8_ocp_t> ||
                              is_same_v<T, bf8_fnuz_t> || is_same_v<T, f8_fnuz_t>,
                          bool> = true>
__host__ __device__ static inline constexpr bool fp8_is_inf(T)
{
    return false;
}
/// Infinity test for OCP E5M2: exponent bits all set, mantissa zero, either sign.
template <>
__host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a)
{
    const int magnitude_bits = a.data & 0x7f;
    return magnitude_bits == 0x7c;
}
namespace
fp8_impl
{
// Assertions to check for supported conversion types
// __assert_ocp_support(interp): fires __hip_assert unless interp is one of the
// two OCP interpretations (CK_E4M3_OCP / CK_E5M2_OCP).
#define __assert_ocp_support(interp) \
{ \
if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \
interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \
{ \
__hip_assert(false && "type is unsupported by current target device"); \
} \
}
// __assert_fnuz_support(interp): fires __hip_assert unless interp is one of the
// two FNUZ interpretations (CK_E4M3_FNUZ / CK_E5M2_FNUZ).
#define __assert_fnuz_support(interp) \
{ \
if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \
interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \
{ \
__hip_assert(false && "type is unsupported by current target device"); \
} \
}
/// Device-side sanity check that \p interp matches the FP8 flavour(s) enabled
/// through CK_USE_OCP_FP8 / CK_USE_FNUZ_FP8. Expands to nothing in host compilation.
/// NOTE(review): if both flavours are enabled at once every interpretation trips
/// one of the two assertions - confirm the flags are meant to be mutually exclusive.
__host__ __device__ static inline void
__is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp)
{
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ && CK_USE_OCP_FP8
    __assert_ocp_support(interp);
#endif
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ && CK_USE_FNUZ_FP8
    __assert_fnuz_support(interp);
#endif
}
#if CK_FP8_CVT_FAST_PATH
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
/**
 * \brief Encode a float as an 8-bit float with the AMDGPU conversion builtins.
 *
 * \tparam interpret           target interpretation
 * \tparam saturate            clamp finite inputs to the target's largest finite magnitude
 *                             before converting (NaN/Inf inputs are passed through unclamped)
 * \tparam stochastic_rounding use the stochastic-rounding builtin with \p rng instead of
 *                             the round-to-nearest-even builtin
 * \param v                    value to encode
 * \param rng                  random bits consumed by stochastic rounding
 * \return                     encoded 8-bit value
 */
template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
{
    constexpr bool is_e4m3 = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
                             (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP);

    // View of the input as float and as raw bits.
    union
    {
        float fval;
        unsigned int i32val;
    } bits;
    bits.fval = v;

    if constexpr(saturate)
    {
        // Largest finite magnitude of the target format.
        constexpr float max_finite =
            (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
                ? 240.0f
                : ((interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) ? 448.0f : 57344.0f);

        /// propagate NAN/INF, no clipping
        const bool exponent_all_ones = (bits.i32val & 0x7F800000) == 0x7F800000;
        if(!exponent_all_ones)
        {
            bits.fval = __builtin_amdgcn_fmed3f(bits.fval, max_finite, -max_finite);
        }
    }

    unsigned int packed = 0;
    if constexpr(stochastic_rounding)
    {
        packed = is_e4m3 ? __builtin_amdgcn_cvt_sr_fp8_f32(bits.fval, rng, packed, 0)
                         : __builtin_amdgcn_cvt_sr_bf8_f32(bits.fval, rng, packed, 0); // 0 pos
    }
    else
    {
        // RNE CVT
        packed = is_e4m3
                     ? __builtin_amdgcn_cvt_pk_fp8_f32(bits.fval, bits.fval, packed, false)
                     : __builtin_amdgcn_cvt_pk_bf8_f32(
                           bits.fval, bits.fval, packed, false); // false -> WORD0
    }

    // The encoded value sits in the least significant byte of the result word.
    return static_cast<fp8_storage_t>(packed & 0xFFu);
}
#endif // CK_FP8_CVT_FAST_PATH
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
// This has been modified to add double types conversion as well
/**
 * \brief Software encode of a half, float or double value as an 8-bit float.
 *
 * \tparam T       source type: _Float16, float or double
 * \tparam wm      mantissa width of the 8-bit format
 * \tparam we      exponent width of the 8-bit format
 * \tparam is_fnuz true for the FNUZ encodings (NaN is 0x80, no negative zero), false for OCP
 * \tparam clip    true: values beyond the largest finite magnitude (including Inf) are
 *                 encoded as the largest finite magnitude; false: they become the
 *                 format's Inf/NaN encoding
 * \tparam stoch   true: stochastic rounding driven by \p rng; false: round to nearest even
 * \param _x       value to encode
 * \param rng      random bits, only used when \p stoch is true
 * \return         encoded 8-bit value
 */
template <typename T, int wm, int we, bool is_fnuz, bool clip = false, bool stoch = false>
__host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rng = 0)
{
    constexpr bool is_half   = __hip_internal::is_same<T, _Float16>::value;
    constexpr bool is_float  = __hip_internal::is_same<T, float>::value;
    constexpr bool is_double = __hip_internal::is_same<T, double>::value;
    static_assert(is_half || is_float || is_double,
                  "Only half, float and double can be cast to f8");

    // Mantissa width of the source format.
    constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);

    using T_bitwise = typename std::conditional<
        sizeof(T) == 2,
        unsigned short int,
        typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
    T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);

    unsigned long long x{x_bitwise};

    unsigned long long head, mantissa;
    int exponent, bias;
    unsigned int sign;
    unsigned long long fInf, mask;

    // Split the source bits into sign / exponent / mantissa.
    if constexpr(sizeof(T) == 8)
    {
        head     = x & 0xFFF0000000000000ull;
        mantissa = x & 0xFFFFFFFFFFFFFull;
        exponent = (head >> 52) & 0x7FF;
        sign     = head >> 63;
        bias     = 1023;
        fInf     = 0x7FF0000000000000ull;
        mask     = 0x7FFFFFFFFFFFFFFFull;
    }
    else if constexpr(sizeof(T) == 4)
    {
        head     = x & 0xFF800000;
        mantissa = x & 0x7FFFFF;
        exponent = (head >> 23) & 0xFF;
        sign     = head >> 31;
        bias     = 127;
        fInf     = 0x7F800000;
        mask     = 0x7FFFFFFF;
    }
    else
    {
        head     = x & 0xFC00;
        mantissa = x & 0x3FF;
        exponent = (head >> 10) & 0x1F;
        sign     = head >> 15;
        bias     = 15;
        fInf     = 0x7C00;
        mask     = 0x7FFF;
    }

    // Encodings returned for overflow/Inf (signed_inf) and for NaN (nan).
    unsigned int signed_inf = 0;
    unsigned int nan        = 0;
    if constexpr(is_fnuz)
    {
        signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
        nan        = 0x80;
    }
    else
    {
        if constexpr(we == 4)
        { // e4m3
            signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
        }
        else
        { // e5m2
            signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
        }
        nan = (sign << 7) + 0x7f;
    }

    // Max values
    unsigned long long ifmax = 0;
    if constexpr(sizeof(T) == 8)
    {
        if constexpr(we == 5)
        { // 57344
            ifmax = 0x40EC000000000000ull;
        }
        else
        {
            if constexpr(is_fnuz)
            { // 240
                ifmax = 0x406E000000000000ull;
            }
            else
            { // 448
                ifmax = 0x407C000000000000ull;
            }
        }
    }
    else if constexpr(sizeof(T) == 4)
    {
        if constexpr(we == 5)
        {
            ifmax = 0x47600000;
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x43700000;
            }
            else
            {
                ifmax = 0x43E00000;
            }
        }
    }
    else
    {
        if constexpr(we == 5)
        {
            ifmax = 0x7B00;
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x5B80;
            }
            else
            {
                ifmax = 0x5F00;
            }
        }
    }

    // Deal with inf and NaNs
    // NOTE(review): for FNUZ with clip=true a NaN input also returns signed_inf,
    // i.e. the largest-magnitude encoding rather than 0x80 - confirm this is intended.
    if((x & fInf) == fInf)
    {
        if constexpr(is_fnuz)
            return signed_inf;

        return mantissa != 0 ? nan : signed_inf;
    }

    if((x & mask) > ifmax)
    {
        return signed_inf;
    }

    if(x == 0)
    {
        return 0;
    }

    // First need to check if it is normal or denorm as there is a difference of
    // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
    // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
    // to mantissa and truncate. And for RNE, no need to add rng. Then probably
    // need to check whether there is carry and adjust exponent and mantissa again

    // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
    // bits
    const int f8_bias                  = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
    const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
    // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
    // f8_exponent is the converted f8 exponent with bias encoding
    // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
    // the difference needs to be adjusted and mantissa shifted
    int act_exponent, f8_exponent, exponent_diff;

    if(exponent == 0)
    { // fp32/fp16 is in denormal.
        /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
    mostly concern fp16 here. In this case, f8 is usually in denormal. But there
    could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
    exponent bias 16. It means that there are some numbers in fp16 denormal but they
    are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
    where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
    (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
        act_exponent  = exponent - bias + 1;
        exponent_diff = f8_denormal_act_exponent -
                        act_exponent; // actual exponent is exponent-bias+1 as it is denormal
    }
    else
    { // fp32/fp16 is normal with implicit 1
        act_exponent = exponent - bias;
        if(act_exponent <= f8_denormal_act_exponent)
        {
            /* This is the case where fp32/fp16 is normal but it is in f8 denormal
      range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
      actual exponent is -7, it is actually larger due to the implicit 1,
      Therefore it needs to be adjust to -6 and mantissa shift right by 1.
      So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
            exponent_diff = f8_denormal_act_exponent - act_exponent;
        }
        else
        {                      // both fp32/fp16 and f8 are in normal range
            exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
                               // for this case, act_exponent could be larger. Just
                               // that it does not need shift mantissa
        }
        mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
    }

    /* This part is a bit tricky. The judgment of whether it is a tie needs to be
  done before we shift right as shift right could rip off some residual part and
  make something not midpoint look like midpoint. For example, the fp16 number
  0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
  by 4 bits, it would look like midpoint.
  */
    // Number of low mantissa bits that will be dropped in total. For zero,
    // denormal or very small inputs this reaches 64 or more; shifting a 64-bit
    // value by that much is undefined, so those cases are handled explicitly:
    // the tie bit would lie above every mantissa bit, hence no midpoint.
    const int dropped_bits = mfmt - wm + exponent_diff;
    bool midpoint          = false;
    if(dropped_bits < 64)
    {
        midpoint = (mantissa & ((1ull << dropped_bits) - 1)) == (1ull << (dropped_bits - 1));
    }

    if(exponent_diff > 0)
    {
        // A shift by 64 or more moves every bit out; spell that out instead of
        // relying on an out-of-range shift count.
        mantissa = (exponent_diff < 64) ? (mantissa >> exponent_diff) : 0ull;
    }
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1ull << mfmt);
    // if there is no implicit 1, it means the f8 is denormal and need to adjust
    // to denorm exponent
    f8_exponent =
        (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);

    // Now we have the exponent and mantissa adjusted
    unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
    bool odd =
        mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
    mantissa +=
        (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;

    // Now we deal with overflow
    if(f8_exponent == 0)
    {
        if((1ull << mfmt) & mantissa)
        {
            f8_exponent = 1; // denormal overflow to become normal, promote exponent
        }
    }
    else
    {
        if((1ull << (mfmt + 1)) & mantissa)
        {
            mantissa >>= 1;
            f8_exponent++;
        }
    }

    mantissa >>= (mfmt - wm);

    // above range: quantize to maximum possible float of the same sign
    const int max_exp = (1 << we) - 1;
    if(f8_exponent > max_exp)
    {
        if constexpr(clip)
        {
            mantissa    = (1 << wm) - 1;
            f8_exponent = max_exp;
        }
        else
        {
            return signed_inf;
        }
    }

    // Rounded to zero: FNUZ has no negative zero, OCP keeps the sign.
    if(f8_exponent == 0 && mantissa == 0)
        return is_fnuz ? 0 : (sign << 7);
    mantissa &= (1 << wm) - 1;
    return (sign << 7) | (f8_exponent << wm) | mantissa;
}
/**
 * \brief convert float to @p fp8_storage_t
 *
 * Uses the builtin-based cast_to_f8_from_f32 when CK_FP8_CVT_FAST_PATH is set and
 * the software cast_to_f8 otherwise.
 *
 * \tparam interp interpretation of fp8
 * \tparam sat saturation of fp8
 * \tparam stochastic_rounding switch between RNE and SR
 * \param f float number
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat      = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__
#else
__host__
#endif
    static inline fp8_storage_t
    cvt_float_to_fp8(const float f)
{
#if CK_FP8_CVT_FAST_PATH
    __is_interpret_supported(interp);
#endif

    constexpr bool saturate_finite = (sat == ck_saturation_t::CK_SATFINITE);

    // Random bits for stochastic rounding, derived from the argument's address and value.
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
        using address_t = uintptr_t;
#else
        using address_t = size_t;
#endif
        rng = prand_generator<float, seed>(reinterpret_cast<address_t>(&f), f);
    }

#if CK_FP8_CVT_FAST_PATH
    return cast_to_f8_from_f32<interp, saturate_finite, stochastic_rounding>(f, rng);
#else
    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
    {
        return cast_to_f8<float, 3, 4, true, saturate_finite, stochastic_rounding>(f, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ)
    {
        return cast_to_f8<float, 2, 5, true, saturate_finite, stochastic_rounding>(f, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return cast_to_f8<float, 3, 4, false, saturate_finite, stochastic_rounding>(f, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return cast_to_f8<float, 2, 5, false, saturate_finite, stochastic_rounding>(f, rng);
    }
    else
    {
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
#endif // CK_FP8_CVT_FAST_PATH
}
/**
 * \brief convert _Float16 to @p fp8_storage_t
 *
 * The value is widened to float and forwarded to cvt_float_to_fp8.
 *
 * \tparam interp interpretation of fp8
 * \tparam sat saturation of fp8
 * \tparam stochastic_rounding switch between RNE and SR
 * \param x _Float16 value
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp,
          ck_saturation_t sat      = ck_saturation_t::CK_SATFINITE,
          bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__
#else
__host__
#endif
    static inline fp8_storage_t
    cvt_half_t_to_fp8(const _Float16 x)
{
    const float widened = static_cast<float>(x);
    return cvt_float_to_fp8<interp, sat, stochastic_rounding>(widened);
}
}
// namespace fp8_impl
// Declare a template function for fp8 conversion using RNE (round to nearest even).
// The primary template is declared only; the conversions are provided by the
// explicit specializations of f8_convert_rne.
// Y: destination type, X: source type.
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_rne(X x);
// fp32 -> fp8 (f8_ocp_t), rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
{
    // Use the destination type's own default interpretation and saturation mode.
    constexpr auto interp = f8_ocp_t::default_interpret;
    constexpr auto sat    = f8_ocp_t::default_saturation;

    const auto raw_bits = fp8_impl::cvt_float_to_fp8<interp, sat>(x);
    return f8_ocp_t{raw_bits};
}
// fp32 -> bf8 (bf8_ocp_t), rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, float>(float x)
{
    // Use the destination type's own default interpretation and saturation mode.
    constexpr auto interp = bf8_ocp_t::default_interpret;
    constexpr auto sat    = bf8_ocp_t::default_saturation;

    const auto raw_bits = fp8_impl::cvt_float_to_fp8<interp, sat>(x);
    return bf8_ocp_t{raw_bits};
}
// _Float16 -> fp8 (f8_ocp_t), rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, _Float16>(_Float16 x)
{
    // Use the destination type's own default interpretation and saturation mode.
    constexpr auto interp = f8_ocp_t::default_interpret;
    constexpr auto sat    = f8_ocp_t::default_saturation;

    const auto raw_bits = fp8_impl::cvt_half_t_to_fp8<interp, sat>(x);
    return f8_ocp_t{raw_bits};
}
// _Float16 -> bf8 (bf8_ocp_t), rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, _Float16>(_Float16 x)
{
    // Use the destination type's own default interpretation and saturation mode.
    constexpr auto interp = bf8_ocp_t::default_interpret;
    constexpr auto sat    = bf8_ocp_t::default_saturation;

    const auto raw_bits = fp8_impl::cvt_half_t_to_fp8<interp, sat>(x);
    return bf8_ocp_t{raw_bits};
}
// Declare a template function for fp8 conversion using SR (stochastic rounding).
// The primary template is declared only; the conversions are provided by the
// explicit specializations of f8_convert_sr.
// Y: destination type, X: source type.
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
// fp32 -> fp8 (f8_ocp_t), stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
{
    constexpr auto interp = f8_ocp_t::default_interpret;
    constexpr auto sat    = f8_ocp_t::default_saturation;

    // The third template argument switches the converter to stochastic rounding.
    const auto raw_bits = fp8_impl::cvt_float_to_fp8<interp, sat, true>(x);
    return f8_ocp_t{raw_bits};
}
// fp32 -> bf8 (bf8_ocp_t), stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, float>(float x)
{
    constexpr auto interp = bf8_ocp_t::default_interpret;
    constexpr auto sat    = bf8_ocp_t::default_saturation;

    // The third template argument switches the converter to stochastic rounding.
    const auto raw_bits = fp8_impl::cvt_float_to_fp8<interp, sat, true>(x);
    return bf8_ocp_t{raw_bits};
}
// _Float16 -> fp8 (f8_ocp_t), stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, _Float16>(_Float16 x)
{
    constexpr auto interp = f8_ocp_t::default_interpret;
    constexpr auto sat    = f8_ocp_t::default_saturation;

    // The third template argument switches the converter to stochastic rounding.
    const auto raw_bits = fp8_impl::cvt_half_t_to_fp8<interp, sat, true>(x);
    return f8_ocp_t{raw_bits};
}
// _Float16 -> bf8 (bf8_ocp_t), stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, _Float16>(_Float16 x)
{
    constexpr auto interp = bf8_ocp_t::default_interpret;
    constexpr auto sat    = bf8_ocp_t::default_saturation;

    // The third template argument switches the converter to stochastic rounding.
    const auto raw_bits = fp8_impl::cvt_half_t_to_fp8<interp, sat, true>(x);
    return bf8_ocp_t{raw_bits};
}
// Select which concrete 8-bit float types the generic names f8_t / bf8_t refer to:
// the OCP types when CK_USE_OCP_FP8 is non-zero, the FNUZ types otherwise.
// CK_FP8_TYPE_FNUZ and CK_FP8_TYPE_OCP are always both defined, one as 1 and the
// other as 0, so the choice can also be tested from the preprocessor.
#if CK_USE_OCP_FP8
using f8_t  = f8_ocp_t;
using bf8_t = bf8_ocp_t;
#define CK_FP8_TYPE_FNUZ 0
#define CK_FP8_TYPE_OCP 1
#else
using f8_t  = f8_fnuz_t;
using bf8_t = bf8_fnuz_t;
#define CK_FP8_TYPE_FNUZ 1
#define CK_FP8_TYPE_OCP 0
#endif
}
// namespace ck
include/ck/utility/amd_inline_asm.hpp
View file @
f1e53807
...
...
@@ -4,13 +4,34 @@
#ifndef CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP
#include "data_type.hpp"
#include "c_style_pointer_cast.hpp"
#include "data_type.hpp"
// TODO: deprecate all amd_assembly_outer_product_xxx
namespace
ck
{
// Issues a single v_and_or_b32 instruction on (a, b, d) and returns its result.
// NOTE(review): per the AMD ISA reference this instruction computes (a & b) | d;
// that semantics comes from the ISA manual, not from this file — confirm there.
// All operands use the "v" register constraint; `asm volatile` prevents the
// compiler from dropping the instruction.
inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
{
    int c;
    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d));
    return c;
}
// Issues a single v_pk_fma_f16 instruction on the half2_t operands (a, b, c) and
// returns the half2_t result.
// NOTE(review): presumably a per-lane fused multiply-add a * b + c on both packed
// halves, as the mnemonic suggests — confirm against the ISA reference.
// All operands use the "v" register constraint.
inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
{
    half2_t d;
    asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(d) : "v"(a), "v"(b), "v"(c));
    return d;
}
// Issues a single v_pk_add_f16 instruction on the half2_t operands (a, b) and
// returns the half2_t result.
// NOTE(review): presumably a per-lane add of both packed halves, as the mnemonic
// suggests — confirm against the ISA reference.
// All operands use the "v" register constraint.
inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
{
    half2_t c;
    asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
    return c;
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__
void
amd_assembly_outer_product_1x2
(
float
a
,
float
b0
,
float
b1
,
float
&
c0
,
float
&
c1
)
...
...
include/ck/utility/amd_wave_read_first_lane.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -7,10 +7,12 @@
#include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp"
#ifndef CK_CODE_GEN_RTC
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#endif
namespace
ck
{
namespace
detail
{
...
...
@@ -37,7 +39,7 @@ struct get_carrier<3>
{
using
value_type
=
uint32_t
;
std
::
a
rray
<
std
::
byte
,
3
>
bytes
;
A
rray
<
ck
::
byte
,
3
>
bytes
;
static_assert
(
sizeof
(
bytes
)
<=
sizeof
(
value_type
));
// replacement of host std::copy_n()
...
...
@@ -61,22 +63,22 @@ struct get_carrier<3>
// method to trigger template substitution failure
__device__
carrier
(
const
carrier
&
other
)
noexcept
{
copy_n
(
other
.
bytes
.
begin
(),
bytes
.
s
ize
(),
bytes
.
begin
());
copy_n
(
other
.
bytes
.
begin
(),
bytes
.
S
ize
(),
bytes
.
begin
());
}
public:
__device__
carrier
&
operator
=
(
value_type
value
)
noexcept
{
copy_n
(
reinterpret_cast
<
const
std
::
byte
*>
(
&
value
),
bytes
.
s
ize
(),
bytes
.
begin
());
copy_n
(
reinterpret_cast
<
const
ck
::
byte
*>
(
&
value
),
bytes
.
S
ize
(),
bytes
.
begin
());
return
*
this
;
}
__device__
operator
value_type
()
const
noexcept
{
std
::
byte
result
[
sizeof
(
value_type
)];
ck
::
byte
result
[
sizeof
(
value_type
)];
copy_n
(
bytes
.
begin
(),
bytes
.
s
ize
(),
result
);
copy_n
(
bytes
.
begin
(),
bytes
.
S
ize
(),
result
);
return
*
reinterpret_cast
<
const
value_type
*>
(
result
);
}
...
...
@@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
{
constexpr
unsigned
object_size
=
sizeof
(
int64_t
);
constexpr
unsigned
second_part_offset
=
object_size
/
2
;
auto
*
const
from_obj
=
reinterpret_cast
<
const
std
::
byte
*>
(
&
value
);
alignas
(
int64_t
)
std
::
byte
to_obj
[
object_size
];
auto
*
const
from_obj
=
reinterpret_cast
<
const
ck
::
byte
*>
(
&
value
);
alignas
(
int64_t
)
ck
::
byte
to_obj
[
object_size
];
using
Sgpr
=
uint32_t
;
...
...
@@ -122,17 +124,16 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
return
*
reinterpret_cast
<
int64_t
*>
(
to_obj
);
}
template
<
typename
Object
,
typename
=
std
::
enable_if_t
<
std
::
is_class_v
<
Object
>
&&
std
::
is_trivially_copyable_v
<
Object
>>>
template
<
typename
Object
,
typename
=
ck
::
enable_if_t
<
ck
::
is_class_v
<
Object
>
&&
ck
::
is_trivially_copyable_v
<
Object
>>>
__device__
auto
amd_wave_read_first_lane
(
const
Object
&
obj
)
{
using
Size
=
unsigned
;
constexpr
Size
SgprSize
=
4
;
constexpr
Size
ObjectSize
=
sizeof
(
Object
);
auto
*
const
from_obj
=
reinterpret_cast
<
const
std
::
byte
*>
(
&
obj
);
alignas
(
Object
)
std
::
byte
to_obj
[
ObjectSize
];
auto
*
const
from_obj
=
reinterpret_cast
<
const
ck
::
byte
*>
(
&
obj
);
alignas
(
Object
)
ck
::
byte
to_obj
[
ObjectSize
];
constexpr
Size
RemainedSize
=
ObjectSize
%
SgprSize
;
constexpr
Size
CompleteSgprCopyBoundary
=
ObjectSize
-
RemainedSize
;
...
...
include/ck/utility/amd_xdlops.hpp
View file @
f1e53807
...
...
@@ -4,8 +4,8 @@
#pragma once
namespace
ck
{
// Define the common macro for
gfx94x
models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// Define the common macro for
MI300
models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|| defined(__gfx950__)
#define __gfx94__
#endif
...
...
@@ -134,6 +134,46 @@ struct intrin_mfma_f32_32x32x4f16<32, 64>
}
};
// Wrapper around __builtin_amdgcn_mfma_f32_32x32x16_f16, selected by
// <MPerWave, NPerWave>. The primary template is declaration-only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f16;

template <>
struct intrin_mfma_f32_32x32x16f16<32, 32>
{
    // reg_a, reg_b: half8_t input operands.
    // reg_c: accumulator; its float16_t view at index 0 is passed to the builtin
    //        and then overwritten with the builtin's result.
    // The builtin is only emitted when compiling for gfx950. On any other target
    // the arguments are just assigned to `ignore` and nothing is computed.
    template <class FloatC>
    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // the three trailing immediate operands of the builtin are all passed as 0
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
// Wrapper around __builtin_amdgcn_mfma_f32_16x16x32_f16, selected by
// <MPerWave, NPerWave>. The primary template is declaration-only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f16;

template <>
struct intrin_mfma_f32_16x16x32f16<16, 16>
{
    // reg_a, reg_b: half8_t input operands.
    // reg_c: accumulator; its float4_t view at index 0 is passed to the builtin
    //        and then overwritten with the builtin's result.
    // The builtin is only emitted when compiling for gfx950. On any other target
    // the arguments are just assigned to `ignore` and nothing is computed.
    template <class FloatC>
    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // the three trailing immediate operands of the builtin are all passed as 0
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x8f16
;
...
...
@@ -204,6 +244,46 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
};
// bfp16
// Wrapper around __builtin_amdgcn_mfma_f32_32x32x16_bf16, selected by
// <MPerWave, NPerWave>. The primary template is declaration-only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf16;

template <>
struct intrin_mfma_f32_32x32x16bf16<32, 32>
{
    // reg_a, reg_b: bhalf8_t input operands.
    // reg_c: accumulator; its float16_t view at index 0 is passed to the builtin
    //        and then overwritten with the builtin's result.
    // The builtin is only emitted when compiling for gfx950. On any other target
    // the arguments are just assigned to `ignore` and nothing is computed.
    template <class FloatC>
    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // the three trailing immediate operands of the builtin are all passed as 0
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
// Wrapper around __builtin_amdgcn_mfma_f32_16x16x32_bf16, selected by
// <MPerWave, NPerWave>. The primary template is declaration-only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf16;

template <>
struct intrin_mfma_f32_16x16x32bf16<16, 16>
{
    // reg_a, reg_b: bhalf8_t input operands.
    // reg_c: accumulator; its float4_t view at index 0 is passed to the builtin
    //        and then overwritten with the builtin's result.
    // The builtin is only emitted when compiling for gfx950. On any other target
    // the arguments are just assigned to `ignore` and nothing is computed.
    template <class FloatC>
    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // the three trailing immediate operands of the builtin are all passed as 0
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x8bf16_1k
;
...
...
@@ -298,6 +378,46 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
}
};
// Wrapper around __builtin_amdgcn_mfma_i32_32x32x32_i8, selected by
// <MPerWave, NPerWave>. The primary template is declaration-only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x32i8;

template <>
struct intrin_mfma_i32_32x32x32i8<32, 32>
{
    // reg_a, reg_b: int8x16_t input operands.
    // reg_c: accumulator; its int32x16_t view at index 0 is passed to the builtin
    //        and then overwritten with the builtin's result.
    // The builtin is only emitted when compiling for gfx950. On any other target
    // the arguments are just assigned to `ignore` and nothing is computed.
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // the three trailing immediate operands of the builtin are all passed as 0
        reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
            reg_a, reg_b, reg_c.template AsType<int32x16_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
// Wrapper around __builtin_amdgcn_mfma_i32_16x16x64_i8, selected by
// <MPerWave, NPerWave>. The primary template is declaration-only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x64i8;

template <>
struct intrin_mfma_i32_16x16x64i8<16, 16>
{
    // reg_a, reg_b: int8x16_t input operands.
    // reg_c: accumulator; its int32x4_t view at index 0 is passed to the builtin
    //        and then overwritten with the builtin's result.
    // The builtin is only emitted when compiling for gfx950. On any other target
    // the arguments are just assigned to `ignore` and nothing is computed.
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // the three trailing immediate operands of the builtin are all passed as 0
        reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
            reg_a, reg_b, reg_c.template AsType<int32x4_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_i32_32x32x16i8
;
...
...
@@ -356,6 +476,149 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
// Primary template is declaration-only; see the <32, 32> specialization.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x64f8f6f4;

/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6,
/// and f4 data types.
///
/// @note Calls scaled version of the instruction as the original instruction is not supported in
/// the backend. That is the intended use. There is a backend optimization to select the unscaled
/// operation if the scale is 0.
template <>
struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
{
    // reg_a, reg_b: f8x32_t input operands.
    // reg_c: accumulator; its float16_t view at index 0 is passed to the builtin
    //        and then overwritten with the builtin's result.
    // Only emitted for gfx950; on other targets the arguments are assigned to
    // `ignore` and nothing is computed.
    template <class FloatC>
    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                reg_a,
                reg_b,
                reg_c.template AsType<float16_t>()[Number<0>{}],
                0, // cbsz
                0, // blgp
                // the remaining four operands (the positions that carry the op-select
                // and scale values in the scaled wrapper) are all passed as 0
                0,
                0,
                0,
                0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
// Wrapper around __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 with
// caller-supplied scale operands. The primary template is declaration-only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_scale_f32_32x32x64f8f6f4;

template <>
struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
{
    // reg_a, reg_b: f8x32_t input operands.
    // scale_a, scale_b: int32_t scale operands forwarded unchanged to the builtin.
    // reg_c: accumulator; its float16_t view at index 0 is passed to the builtin
    //        and then overwritten with the builtin's result.
    // Only emitted for gfx950; on other targets every argument is assigned to
    // `ignore` and nothing is computed.
    template <class FloatC>
    __device__ static void Run(const f8x32_t& reg_a,
                               const int32_t scale_a,
                               const f8x32_t& reg_b,
                               const int32_t scale_b,
                               FloatC& reg_c)
    {
#if defined(__gfx950__)
        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                reg_a,
                reg_b,
                reg_c.template AsType<float16_t>()[Number<0>{}],
                0, // cbsz
                0, // blgp
                0, // { OPSEL_HI[0], OPSEL[0] }?
                scale_a,
                0, // { OPSEL_HI[1], OPSEL[1] }?
                scale_b);
#else
        ignore = reg_a;
        ignore = scale_a;
        ignore = reg_b;
        ignore = scale_b;
        ignore = reg_c;
#endif
    }
};
// Wrapper around __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 with
// caller-supplied scale operands. The primary template is declaration-only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_scale_f32_16x16x128f8f6f4;

template <>
struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
{
    // reg_a, reg_b: f8x32_t input operands.
    // scale_a, scale_b: int32_t scale operands forwarded unchanged to the builtin.
    // reg_c: accumulator; its float4_t view at index 0 is passed to the builtin
    //        and then overwritten with the builtin's result.
    // Only emitted for gfx950; on other targets every argument is assigned to
    // `ignore` and nothing is computed.
    template <class FloatC>
    __device__ static void Run(const f8x32_t& reg_a,
                               const int32_t scale_a,
                               const f8x32_t& reg_b,
                               const int32_t scale_b,
                               FloatC& reg_c)
    {
#if defined(__gfx950__)
        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
                reg_a,
                reg_b,
                reg_c.template AsType<float4_t>()[Number<0>{}],
                0, // cbsz
                0, // blgp
                0, // { OPSEL_HI[0], OPSEL[0] }?
                scale_a,
                0, // { OPSEL_HI[1], OPSEL[1] }?
                scale_b);
#else
        ignore = reg_a;
        ignore = scale_a;
        ignore = reg_b;
        ignore = scale_b;
        ignore = reg_c;
#endif
    }
};
// Primary template is declaration-only; see the <16, 16> specialization.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x128f8f6f4;

/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4
/// data types.
///
/// @note Calls scaled version of the instruction as the original instruction is not supported in
/// the backend. That is the intended use. There is a backend optimization to select the unscaled
/// operation if the scale is 0.
template <>
struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
{
    // reg_a, reg_b: f8x32_t input operands.
    // reg_c: accumulator; its float4_t view at index 0 is passed to the builtin
    //        and then overwritten with the builtin's result.
    // Only emitted for gfx950; on other targets the arguments are assigned to
    // `ignore` and nothing is computed.
    template <class FloatC>
    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
                reg_a,
                reg_b,
                reg_c.template AsType<float4_t>()[Number<0>{}],
                0, // cbsz
                0, // blgp
                // the remaining four operands (the positions that carry the op-select
                // and scale values in the scaled wrapper) are all passed as 0
                0,
                0,
                0,
                0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8f8
;
...
...
include/ck/utility/array.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP
...
...
@@ -38,6 +38,8 @@ struct Array
}
__host__
__device__
constexpr
const
TData
*
begin
()
const
{
return
&
mData
[
0
];
}
__host__
__device__
constexpr
const
TData
*
end
()
const
{
return
&
mData
[
NSize
];
}
__host__
__device__
constexpr
TData
*
begin
()
{
return
&
mData
[
0
];
}
__host__
__device__
constexpr
TData
*
end
()
{
return
&
mData
[
NSize
];
}
};
// empty Array
...
...
@@ -54,7 +56,7 @@ template <typename X, typename... Xs>
__host__
__device__
constexpr
auto
make_array
(
X
&&
x
,
Xs
&&
...
xs
)
{
using
data_type
=
remove_cvref_t
<
X
>
;
return
Array
<
data_type
,
sizeof
...(
Xs
)
+
1
>
{
std
::
forward
<
X
>
(
x
),
std
::
forward
<
Xs
>
(
xs
)...};
return
Array
<
data_type
,
sizeof
...(
Xs
)
+
1
>
{
ck
::
forward
<
X
>
(
x
),
ck
::
forward
<
Xs
>
(
xs
)...};
}
// make empty array
...
...
include/ck/utility/blkgemmpipe_scheduler.hpp
View file @
f1e53807
...
...
@@ -90,14 +90,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
KPerXDL
);
printf
(
" A/B buffer load inst: %d, %d
\n
A/B LDS write inst: %d, %d
\n
A/B LDS read inst: "
"%d, %d
\n
C MFMA inst: %d
\n
"
,
"%d, %d
\n
C MFMA inst: %d
\n
"
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
"%d/ %d
\n
"
,
A_Buffer_Load_Inst_Num
,
B_Buffer_Load_Inst_Num
,
A_LDS_Write_Inst_Num
,
B_LDS_Write_Inst_Num
,
A_LDS_Read_Inst_Num
,
B_LDS_Read_Inst_Num
,
C_MFMA_Inst_Num
);
C_MFMA_Inst_Num
,
A_LDS_Read_Width
,
B_LDS_Read_Width
,
ALDSWriteWidth
,
BLDSWriteWidth
,
ABufferLoadWidth
,
BBufferLoadWidth
);
}
};
...
...
include/ck/utility/container_helper.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CONTAINER_HELPER_HPP
#define CK_CONTAINER_HELPER_HPP
...
...
@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
__host__
__device__
constexpr
auto
container_concat
(
const
Array
<
T
,
NX
>&
ax
,
const
Array
<
T
,
NY
>&
ay
)
{
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
make_array
(
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
ax
,
ay
);
[
&
](
auto
&&
...
zs
)
{
return
make_array
(
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
ax
,
ay
);
}
template
<
typename
...
X
,
typename
...
Y
>
__host__
__device__
constexpr
auto
container_concat
(
const
Tuple
<
X
...
>&
tx
,
const
Tuple
<
Y
...
>&
ty
)
{
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
make_tuple
(
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
tx
,
ty
);
[
&
](
auto
&&
...
zs
)
{
return
make_tuple
(
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
tx
,
ty
);
}
template
<
typename
Container
>
...
...
include/ck/utility/data_type.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_ck_fp8.hpp"
#include "ck/utility/e8m0.hpp"
#include "ck/utility/statically_indexed_array.hpp"
#ifdef CK_CODE_GEN_RTC
using
int8_t
=
signed
char
;
using
uint8_t
=
unsigned
char
;
using
int16_t
=
signed
short
;
using
uint16_t
=
unsigned
short
;
using
float_t
=
float
;
#endif
namespace
ck
{
// ck::byte: an alias for unsigned char under CK_CODE_GEN_RTC, otherwise std::byte
// pulled into namespace ck with a using-declaration.
#ifdef CK_CODE_GEN_RTC
using byte = unsigned char;
#else
using std::byte;
#endif
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
using
int4_t
=
_BitInt
(
4
);
using
f8_t
=
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
using
f4_t
=
unsigned
_BitInt
(
4
);
using
f6_t
=
_BitInt
(
6
);
// e2m3 format
using
bf6_t
=
unsigned
_BitInt
(
6
);
// e3m2 format
// Two 4-bit values packed into one byte: element 0 lives in the low nibble,
// element 1 in the high nibble.
struct f4x2_pk_t
{
    using type = uint8_t;
    type data;

    // Constructors are usable from both host and device code (and in constant
    // expressions), matching pk_i4_t.
    __host__ __device__ constexpr f4x2_pk_t() : data{type{}} {}
    __host__ __device__ constexpr f4x2_pk_t(type init) : data{init} {}

    // Return element I of the stored byte: I == 0 -> low nibble, I == 1 -> high nibble.
    template <index_t I>
    __host__ __device__ inline type unpack(Number<I>) const
    {
        static_assert(I >= 0 && I < 2, "Index is out of range.");
        if constexpr(I == 0)
            return static_cast<type>(data & 0b00001111);
        else
            return static_cast<type>(data >> 4);
    }

    // Combine two nibbles into one byte: x0 goes to the low nibble (masked to 4 bits),
    // x1 to the high nibble. Does not read or modify `data`; the packed byte is returned.
    __host__ __device__ inline type pack(const type x0, const type x1) const
    {
        // the shift is done in int; the cast drops everything above bit 7
        return static_cast<type>((x1 << 4) | (x0 & 0b00001111));
    }
};
// 16 six-bit f6_t values stored back to back (element I starts at bit 6 * I)
// in three 32-bit words.
struct f6x16_pk_t
{
    // store 16 elements of f6_t in an array of 3 uint32_t
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 3>;
    type data;
    // unpacked input accepted by pack(): one int8_t lane per element
    typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));

    // Constructors carry the same host/device qualifiers as unpack()/pack() so the
    // type can be created in device code too.
    __host__ __device__ f6x16_pk_t() : data{type{}} {}
    __host__ __device__ f6x16_pk_t(type init) : data{init} {}

    // Extract element I (compile-time index) from `data`.
    template <index_t I>
    __host__ __device__ inline f6_t unpack(Number<I>)
    {
        static_assert(I >= 0 && I < 16, "Index out of range for 16 f6_t elements.");

        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 3;
        constexpr int bit_pos           = I * num_bits_elem;
        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
        uint32_t bits                   = data.At(Number<arr_idx>{}) >> bit_offset;
        // number of bits of this element that spill into the next 32-bit word
        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;

        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            // take the low `overhang` bits of the next word as the element's high bits
            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
                    << (num_bits_elem - overhang);
        }

        return static_cast<f6_t>(bits & 0x3F);
    }

    // Pack the low 6 bits of each lane of x into a fresh `type` value and return it.
    // Does not read or modify `data`.
    __host__ __device__ inline type pack(const test_vec_t& x)
    {
        type packed{};

        // for each of the 16 f6_t values, place its 6 bits in the correct position
        ck::static_for<0, 16, 1>{}([&](auto i) {
            uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
            constexpr int num_bits_elem     = 6;
            constexpr int num_bits_vec_elem = 32;
            constexpr int vector_size       = 3;
            constexpr int bit_pos           = i * num_bits_elem;
            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;
            uint32_t old_value              = packed.At(Number<arr_index>{});

            // insert bits into the current 32-bit block
            old_value |= (bits << bit_offset);
            packed.At(Number<arr_index>{}) = old_value;

            // if it crosses into the next block, shift the remainder
            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
            {
                uint32_t next_value = packed.At(Number<arr_index + 1>{});
                next_value |= (bits >> (num_bits_elem - overhang));
                packed.At(Number<arr_index + 1>{}) = next_value;
            }
        });

        return packed;
    }
};
// 32 six-bit f6_t values stored back to back (element I starts at bit 6 * I)
// in six 32-bit words.
struct f6x32_pk_t
{
    // store 32 elements of f6_t in an array of 6 uint32_t
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 6>;
    type data;
    // unpacked input accepted by pack(): one int8_t lane per element
    typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));

    // Constructors carry the same host/device qualifiers as unpack()/pack() so the
    // type can be created in device code too.
    __host__ __device__ f6x32_pk_t() : data{type{}} {}
    __host__ __device__ f6x32_pk_t(type init) : data{init} {}

    // Extract element I (compile-time index) from `data`.
    template <index_t I>
    __host__ __device__ inline f6_t unpack(Number<I>)
    {
        static_assert(I >= 0 && I < 32, "Index out of range for 32 f6_t elements.");

        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 6;
        constexpr int bit_pos           = I * num_bits_elem;
        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
        uint32_t bits                   = data.At(Number<arr_idx>{}) >> bit_offset;
        // number of bits of this element that spill into the next 32-bit word
        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;

        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            // take the low `overhang` bits of the next word as the element's high bits
            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
                    << (num_bits_elem - overhang);
        }

        return static_cast<f6_t>(bits & 0x3F);
    }

    // Pack the low 6 bits of each lane of x into a fresh `type` value and return it.
    // Does not read or modify `data`.
    __host__ __device__ inline type pack(const test_vec_t& x)
    {
        type packed{};

        // for each of the 32 f6_t values, place its 6 bits in the correct position
        ck::static_for<0, 32, 1>{}([&](auto i) {
            uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
            constexpr int num_bits_elem     = 6;
            constexpr int num_bits_vec_elem = 32;
            constexpr int vector_size       = 6;
            constexpr int bit_pos           = i * num_bits_elem;
            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;
            uint32_t old_value              = packed.At(Number<arr_index>{});

            // insert bits into the current 32-bit block
            old_value |= (bits << bit_offset);
            packed.At(Number<arr_index>{}) = old_value;

            // if it crosses into the next block, shift the remainder
            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
            {
                uint32_t next_value = packed.At(Number<arr_index + 1>{});
                next_value |= (bits >> (num_bits_elem - overhang));
                packed.At(Number<arr_index + 1>{}) = next_value;
            }
        });

        return packed;
    }
};
// 16 six-bit bf6_t values stored back to back (element I starts at bit 6 * I)
// in three 32-bit words.
struct bf6x16_pk_t
{
    // store 16 elements of bf6_t in an array of 3 uint32_t
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 3>;
    type data;
    // unpacked input accepted by pack(): one int8_t lane per element
    typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));

    // Constructors carry the same host/device qualifiers as unpack()/pack() so the
    // type can be created in device code too.
    __host__ __device__ bf6x16_pk_t() : data{type{}} {}
    __host__ __device__ bf6x16_pk_t(type init) : data{init} {}

    // Extract element I (compile-time index) from `data`.
    template <index_t I>
    __host__ __device__ inline bf6_t unpack(Number<I>)
    {
        static_assert(I >= 0 && I < 16, "Index out of range for 16 bf6_t elements.");

        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 3;
        constexpr int bit_pos           = I * num_bits_elem;
        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
        uint32_t bits                   = data.At(Number<arr_idx>{}) >> bit_offset;
        // number of bits of this element that spill into the next 32-bit word
        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;

        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            // take the low `overhang` bits of the next word as the element's high bits
            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
                    << (num_bits_elem - overhang);
        }

        return static_cast<bf6_t>(bits & 0x3F);
    }

    // Pack the low 6 bits of each lane of x into a fresh `type` value and return it.
    // Does not read or modify `data`.
    __host__ __device__ inline type pack(const test_vec_t& x)
    {
        type packed{};

        // for each of the 16 bf6_t values, place its 6 bits in the correct position
        ck::static_for<0, 16, 1>{}([&](auto i) {
            uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
            constexpr int num_bits_elem     = 6;
            constexpr int num_bits_vec_elem = 32;
            constexpr int vector_size       = 3;
            constexpr int bit_pos           = i * num_bits_elem;
            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;
            uint32_t old_value              = packed.At(Number<arr_index>{});

            // insert bits into the current 32-bit block
            old_value |= (bits << bit_offset);
            packed.At(Number<arr_index>{}) = old_value;

            // if it crosses into the next block, shift the remainder
            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
            {
                uint32_t next_value = packed.At(Number<arr_index + 1>{});
                next_value |= (bits >> (num_bits_elem - overhang));
                packed.At(Number<arr_index + 1>{}) = next_value;
            }
        });

        return packed;
    }
};
// 32 six-bit bf6_t values stored back to back (element I starts at bit 6 * I)
// in six 32-bit words.
struct bf6x32_pk_t
{
    // store 32 elements of bf6_t in an array of 6 uint32_t
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 6>;
    type data;
    // unpacked input accepted by pack(): one int8_t lane per element
    typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));

    // Constructors carry the same host/device qualifiers as unpack()/pack() so the
    // type can be created in device code too.
    __host__ __device__ bf6x32_pk_t() : data{type{}} {}
    __host__ __device__ bf6x32_pk_t(type init) : data{init} {}

    // Extract element I (compile-time index) from `data`.
    template <index_t I>
    __host__ __device__ inline bf6_t unpack(Number<I>)
    {
        static_assert(I >= 0 && I < 32, "Index out of range for 32 bf6_t elements.");

        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 6;
        constexpr int bit_pos           = I * num_bits_elem;
        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
        uint32_t bits                   = data.At(Number<arr_idx>{}) >> bit_offset;
        // number of bits of this element that spill into the next 32-bit word
        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;

        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            // take the low `overhang` bits of the next word as the element's high bits
            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
                    << (num_bits_elem - overhang);
        }

        return static_cast<bf6_t>(bits & 0x3F);
    }

    // Pack the low 6 bits of each lane of x into a fresh `type` value and return it.
    // Does not read or modify `data`.
    __host__ __device__ inline type pack(const test_vec_t& x)
    {
        type packed{};

        // for each of the 32 bf6_t values, place its 6 bits in the correct position
        ck::static_for<0, 32, 1>{}([&](auto i) {
            uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
            constexpr int num_bits_elem     = 6;
            constexpr int num_bits_vec_elem = 32;
            constexpr int vector_size       = 6;
            constexpr int bit_pos           = i * num_bits_elem;
            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;
            uint32_t old_value              = packed.At(Number<arr_index>{});

            // insert bits into the current 32-bit block
            old_value |= (bits << bit_offset);
            packed.At(Number<arr_index>{}) = old_value;

            // if it crosses into the next block, shift the remainder
            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
            {
                uint32_t next_value = packed.At(Number<arr_index + 1>{});
                next_value |= (bits >> (num_bits_elem - overhang));
                packed.At(Number<arr_index + 1>{}) = next_value;
            }
        });

        return packed;
    }
};
// custom data type - pack int4 data
// A thin wrapper around one int8_t of storage.
struct pk_i4_t
{
    using type = int8_t;
    type data;

    // default construction zero-initializes the storage
    __host__ __device__ constexpr pk_i4_t() : data{} {}
    // wrap an existing storage value
    __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
};
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
{
...
...
@@ -19,14 +330,16 @@ inline constexpr auto next_pow2(uint32_t x)
return
x
>
1u
?
(
1u
<<
(
32u
-
__builtin_clz
(
x
-
1u
)))
:
x
;
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool, f4_t, f6_t, bf6_t
template
<
typename
T
>
inline
constexpr
bool
is_native_type
()
{
return
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
bhalf_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
uint8_t
>::
value
||
is_same
<
T
,
f8_t
>::
value
||
is_same
<
T
,
bf8_t
>::
value
||
is_same
<
T
,
bool
>::
value
;
is_same
<
T
,
uint8_t
>::
value
||
is_same
<
T
,
f8_fnuz_t
>::
value
||
is_same
<
T
,
bf8_fnuz_t
>::
value
||
is_same
<
T
,
bool
>::
value
||
is_same
<
T
,
f4_t
>::
value
||
is_same
<
T
,
f6_t
>::
value
||
is_same
<
T
,
bf6_t
>::
value
;
}
// vector_type
...
...
@@ -166,16 +479,37 @@ struct scalar_type<int4_t>
#endif
template
<
>
struct
scalar_type
<
f8_t
>
struct
scalar_type
<
pk_i4_t
>
{
using
type
=
pk_i4_t
;
static
constexpr
index_t
vector_size
=
1
;
};
template
<
>
struct
scalar_type
<
f8_fnuz_t
>
{
using
type
=
f8_fnuz_t
;
static
constexpr
index_t
vector_size
=
1
;
};
template
<
>
struct
scalar_type
<
bf8_fnuz_t
>
{
using
type
=
bf8_fnuz_t
;
static
constexpr
index_t
vector_size
=
1
;
};
template
<
>
struct
scalar_type
<
f8_ocp_t
>
{
using
type
=
f8_
t
;
using
type
=
f8_
ocp_t
::
data_type
;
static
constexpr
index_t
vector_size
=
1
;
};
template
<
>
struct
scalar_type
<
bf8_t
>
struct
scalar_type
<
bf8_
ocp_
t
>
{
using
type
=
bf8_
t
;
using
type
=
bf8_
ocp_t
::
data_type
;
static
constexpr
index_t
vector_size
=
1
;
};
...
...
@@ -187,7 +521,7 @@ struct scalar_type<bool>
};
template
<
typename
T
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
1
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
type
=
d1_t
;
...
...
@@ -223,7 +557,7 @@ struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
__device__
int
static
err
=
0
;
template
<
typename
T
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
2
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -283,20 +617,20 @@ struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
3
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d
4
_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d
3
_t
__attribute__
((
ext_vector_type
(
3
)));
using
type
=
d
4
_t
;
using
type
=
d
3
_t
;
union
{
d
4
_t
d
4
_
;
StaticallyIndexedArray
<
d1_t
,
4
>
d1x
4
_
;
StaticallyIndexedArray
<
d2_t
,
2
>
d2x
2
_
;
StaticallyIndexedArray
<
d
4
_t
,
1
>
d
4
x1_
;
d
3
_t
d
3
_
;
StaticallyIndexedArray
<
d1_t
,
3
>
d1x
3
_
;
StaticallyIndexedArray
<
d2_t
,
1
>
d2x
1
_
;
StaticallyIndexedArray
<
d
3
_t
,
1
>
d
3
x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
...
...
@@ -306,20 +640,20 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d
4
_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d
3
_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
4
_
;
return
data_
.
d1x
3
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
2
_
;
return
data_
.
d2x
1
_
;
}
else
if
constexpr
(
is_same
<
X
,
d
4
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
3
_t
>::
value
)
{
return
data_
.
d
4
x1_
;
return
data_
.
d
3
x1_
;
}
else
{
...
...
@@ -330,20 +664,20 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d
4
_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d
3
_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
4
_
;
return
data_
.
d1x
3
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
2
_
;
return
data_
.
d2x
1
_
;
}
else
if
constexpr
(
is_same
<
X
,
d
4
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
3
_t
>::
value
)
{
return
data_
.
d
4
x1_
;
return
data_
.
d
3
x1_
;
}
else
{
...
...
@@ -353,22 +687,20 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
4
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
using
type
=
d
8
_t
;
using
type
=
d
4
_t
;
union
{
d8_t
d8_
;
StaticallyIndexedArray
<
d1_t
,
8
>
d1x8_
;
StaticallyIndexedArray
<
d2_t
,
4
>
d2x4_
;
StaticallyIndexedArray
<
d4_t
,
2
>
d4x2_
;
StaticallyIndexedArray
<
d8_t
,
1
>
d8x1_
;
d4_t
d4_
;
StaticallyIndexedArray
<
d1_t
,
4
>
d1x4_
;
StaticallyIndexedArray
<
d2_t
,
2
>
d2x2_
;
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
...
...
@@ -378,25 +710,20 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
8
_
;
return
data_
.
d1x
4
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
4
_
;
return
data_
.
d2x
2
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
return
data_
.
d4x1_
;
}
else
{
...
...
@@ -407,25 +734,20 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
8
_
;
return
data_
.
d1x
4
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
4
_
;
return
data_
.
d2x
2
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
return
data_
.
d4x1_
;
}
else
{
...
...
@@ -435,24 +757,20 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
5
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d5_t
__attribute__
((
ext_vector_type
(
5
)));
using
type
=
d
16
_t
;
using
type
=
d
5
_t
;
union
{
d16_t
d16_
;
StaticallyIndexedArray
<
d1_t
,
16
>
d1x16_
;
StaticallyIndexedArray
<
d2_t
,
8
>
d2x8_
;
StaticallyIndexedArray
<
d4_t
,
4
>
d4x4_
;
StaticallyIndexedArray
<
d8_t
,
2
>
d8x2_
;
StaticallyIndexedArray
<
d16_t
,
1
>
d16x1_
;
d5_t
d5_
;
StaticallyIndexedArray
<
d1_t
,
5
>
d1x5_
;
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
StaticallyIndexedArray
<
d5_t
,
1
>
d5x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
...
...
@@ -462,30 +780,20 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d5_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x8_
;
return
data_
.
d1x5_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x2_
;
return
data_
.
d4x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d
16
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
5
_t
>::
value
)
{
return
data_
.
d
16
x1_
;
return
data_
.
d
5
x1_
;
}
else
{
...
...
@@ -496,30 +804,20 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d5_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x8_
;
return
data_
.
d1x5_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x2_
;
return
data_
.
d4x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d
16
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
5
_t
>::
value
)
{
return
data_
.
d
16
x1_
;
return
data_
.
d
5
x1_
;
}
else
{
...
...
@@ -529,26 +827,22 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
32
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
7
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d7_t
__attribute__
((
ext_vector_type
(
7
)));
using
type
=
d
32
_t
;
using
type
=
d
7
_t
;
union
{
d32_t
d32_
;
StaticallyIndexedArray
<
d1_t
,
32
>
d1x32_
;
StaticallyIndexedArray
<
d2_t
,
16
>
d2x16_
;
StaticallyIndexedArray
<
d4_t
,
8
>
d4x8_
;
StaticallyIndexedArray
<
d8_t
,
4
>
d8x4_
;
StaticallyIndexedArray
<
d16_t
,
2
>
d16x2_
;
StaticallyIndexedArray
<
d32_t
,
1
>
d32x1_
;
d7_t
d7_
;
StaticallyIndexedArray
<
d1_t
,
7
>
d1x7_
;
StaticallyIndexedArray
<
d2_t
,
3
>
d2x3_
;
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
StaticallyIndexedArray
<
d7_t
,
1
>
d7x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
...
...
@@ -559,33 +853,24 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d7_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
32
_
;
return
data_
.
d1x
7
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
16
_
;
return
data_
.
d2x
3
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x2_
;
return
data_
.
d4x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d
32
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
7
_t
>::
value
)
{
return
data_
.
d
32
x1_
;
return
data_
.
d
7
x1_
;
}
else
{
...
...
@@ -597,64 +882,49 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d7_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
32
_
;
return
data_
.
d1x
7
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
16
_
;
return
data_
.
d2x
3
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x
8
_
;
return
data_
.
d4x
1
_
;
}
else
if
constexpr
(
is_same
<
X
,
d
8
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
7
_t
>::
value
)
{
return
data_
.
d
8x4
_
;
return
data_
.
d
7x1
_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
else
{
return
data_
.
d16x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x1_
;
}
else
{
return
err
;
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
64
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
8
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
using
type
=
d
64
_t
;
using
type
=
d
8
_t
;
union
{
d64_t
d64_
;
StaticallyIndexedArray
<
d1_t
,
64
>
d1x64_
;
StaticallyIndexedArray
<
d2_t
,
32
>
d2x32_
;
StaticallyIndexedArray
<
d4_t
,
16
>
d4x16_
;
StaticallyIndexedArray
<
d8_t
,
8
>
d8x8_
;
StaticallyIndexedArray
<
d16_t
,
4
>
d16x4_
;
StaticallyIndexedArray
<
d32_t
,
2
>
d32x2_
;
StaticallyIndexedArray
<
d64_t
,
1
>
d64x1_
;
d8_t
d8_
;
StaticallyIndexedArray
<
d1_t
,
8
>
d1x8_
;
StaticallyIndexedArray
<
d2_t
,
4
>
d2x4_
;
StaticallyIndexedArray
<
d4_t
,
2
>
d4x2_
;
StaticallyIndexedArray
<
d8_t
,
1
>
d8x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
...
...
@@ -665,81 +935,135 @@ struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
64
_
;
return
data_
.
d1x
8
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
32
_
;
return
data_
.
d2x
4
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x
16
_
;
return
data_
.
d4x
2
_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x
8
_
;
return
data_
.
d8x
1
_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
else
{
return
data_
.
d16x4_
;
return
err
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
return
data_
.
d32x2_
;
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d
64
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
2
_t
>::
value
)
{
return
data_
.
d64x1_
;
return
data_
.
d2x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
13
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d13_t
__attribute__
((
ext_vector_type
(
13
)));
using
type
=
d13_t
;
union
{
d13_t
d13_
;
StaticallyIndexedArray
<
d1_t
,
13
>
d1x13_
;
StaticallyIndexedArray
<
d4_t
,
3
>
d4x3_
;
StaticallyIndexedArray
<
d8_t
,
1
>
d8x1_
;
StaticallyIndexedArray
<
d13_t
,
1
>
d13x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d13_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x32_
;
return
data_
.
d1x13_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x
16
_
;
return
data_
.
d4x
3
_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x
8
_
;
return
data_
.
d8x
1
_
;
}
else
if
constexpr
(
is_same
<
X
,
d1
6
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d1
3
_t
>::
value
)
{
return
data_
.
d1
6x4
_
;
return
data_
.
d1
3x1
_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
else
{
return
data_
.
d32x2_
;
return
err
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
return
data_
.
d64x1_
;
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d13_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x13_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x3_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d13_t
>::
value
)
{
return
data_
.
d13x1_
;
}
else
{
...
...
@@ -749,30 +1073,24 @@ struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
1
28
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
1
6
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
typedef
T
d128_t
__attribute__
((
ext_vector_type
(
128
)));
using
type
=
d1
28
_t
;
using
type
=
d1
6
_t
;
union
{
d128_t
d128_
;
StaticallyIndexedArray
<
d1_t
,
128
>
d1x128_
;
StaticallyIndexedArray
<
d2_t
,
64
>
d2x64_
;
StaticallyIndexedArray
<
d4_t
,
32
>
d4x32_
;
StaticallyIndexedArray
<
d8_t
,
16
>
d8x16_
;
StaticallyIndexedArray
<
d16_t
,
8
>
d16x8_
;
StaticallyIndexedArray
<
d32_t
,
4
>
d32x4_
;
StaticallyIndexedArray
<
d64_t
,
2
>
d64x2_
;
StaticallyIndexedArray
<
d128_t
,
1
>
d128x1_
;
d16_t
d16_
;
StaticallyIndexedArray
<
d1_t
,
16
>
d1x16_
;
StaticallyIndexedArray
<
d2_t
,
8
>
d2x8_
;
StaticallyIndexedArray
<
d4_t
,
4
>
d4x4_
;
StaticallyIndexedArray
<
d8_t
,
2
>
d8x2_
;
StaticallyIndexedArray
<
d16_t
,
1
>
d16x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
...
...
@@ -784,41 +1102,28 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
,
is_same
<
X
,
d16_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x1
28
_
;
return
data_
.
d1x1
6
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
64
_
;
return
data_
.
d2x
8
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x
32
_
;
return
data_
.
d4x
4
_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x
16
_
;
return
data_
.
d8x
2
_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d128_t
>::
value
)
{
return
data_
.
d128x1_
;
return
data_
.
d16x1_
;
}
else
{
...
...
@@ -831,41 +1136,28 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
,
is_same
<
X
,
d16_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x1
28
_
;
return
data_
.
d1x1
6
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
64
_
;
return
data_
.
d2x
8
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x
32
_
;
return
data_
.
d4x
4
_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x
16
_
;
return
data_
.
d8x
2
_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d128_t
>::
value
)
{
return
data_
.
d128x1_
;
return
data_
.
d16x1_
;
}
else
{
...
...
@@ -875,7 +1167,7 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
2
56
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
3
2
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -883,24 +1175,18 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
typedef
T
d128_t
__attribute__
((
ext_vector_type
(
128
)));
typedef
T
d256_t
__attribute__
((
ext_vector_type
(
256
)));
using
type
=
d2
56
_t
;
using
type
=
d
3
2_t
;
union
{
d256_t
d256_
;
StaticallyIndexedArray
<
d1_t
,
256
>
d1x256_
;
StaticallyIndexedArray
<
d2_t
,
128
>
d2x128_
;
StaticallyIndexedArray
<
d4_t
,
64
>
d4x64_
;
StaticallyIndexedArray
<
d8_t
,
32
>
d8x32_
;
StaticallyIndexedArray
<
d16_t
,
16
>
d16x16_
;
StaticallyIndexedArray
<
d32_t
,
8
>
d32x8_
;
StaticallyIndexedArray
<
d64_t
,
4
>
d64x4_
;
StaticallyIndexedArray
<
d128_t
,
2
>
d128x2_
;
StaticallyIndexedArray
<
d256_t
,
1
>
d256x1_
;
d32_t
d32_
;
StaticallyIndexedArray
<
d1_t
,
32
>
d1x32_
;
StaticallyIndexedArray
<
d2_t
,
16
>
d2x16_
;
StaticallyIndexedArray
<
d4_t
,
8
>
d4x8_
;
StaticallyIndexedArray
<
d8_t
,
4
>
d8x4_
;
StaticallyIndexedArray
<
d16_t
,
2
>
d16x2_
;
StaticallyIndexedArray
<
d32_t
,
1
>
d32x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
...
...
@@ -910,47 +1196,34 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
||
is_same
<
X
,
d256_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x2
56
_
;
return
data_
.
d1x
3
2_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x1
28
_
;
return
data_
.
d2x1
6
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x
64
_
;
return
data_
.
d4x
8
_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x
32
_
;
return
data_
.
d8x
4
_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x
16
_
;
return
data_
.
d16x
2
_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d128_t
>::
value
)
{
return
data_
.
d128x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d256_t
>::
value
)
{
return
data_
.
d256x1_
;
return
data_
.
d32x1_
;
}
else
{
...
...
@@ -961,23 +1234,337 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
||
is_same
<
X
,
d256_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x2
56
_
;
return
data_
.
d1x
3
2_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x1
28
_
;
return
data_
.
d2x1
6
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x64_
;
return
data_
.
d4x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
64
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
using
type
=
d64_t
;
union
{
d64_t
d64_
;
StaticallyIndexedArray
<
d1_t
,
64
>
d1x64_
;
StaticallyIndexedArray
<
d2_t
,
32
>
d2x32_
;
StaticallyIndexedArray
<
d4_t
,
16
>
d4x16_
;
StaticallyIndexedArray
<
d8_t
,
8
>
d8x8_
;
StaticallyIndexedArray
<
d16_t
,
4
>
d16x4_
;
StaticallyIndexedArray
<
d32_t
,
2
>
d32x2_
;
StaticallyIndexedArray
<
d64_t
,
1
>
d64x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
128
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
typedef
T
d128_t
__attribute__
((
ext_vector_type
(
128
)));
using
type
=
d128_t
;
union
{
d128_t
d128_
;
StaticallyIndexedArray
<
d1_t
,
128
>
d1x128_
;
StaticallyIndexedArray
<
d2_t
,
64
>
d2x64_
;
StaticallyIndexedArray
<
d4_t
,
32
>
d4x32_
;
StaticallyIndexedArray
<
d8_t
,
16
>
d8x16_
;
StaticallyIndexedArray
<
d16_t
,
8
>
d16x8_
;
StaticallyIndexedArray
<
d32_t
,
4
>
d32x4_
;
StaticallyIndexedArray
<
d64_t
,
2
>
d64x2_
;
StaticallyIndexedArray
<
d128_t
,
1
>
d128x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x128_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d128_t
>::
value
)
{
return
data_
.
d128x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x128_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d128_t
>::
value
)
{
return
data_
.
d128x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
256
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
typedef
T
d128_t
__attribute__
((
ext_vector_type
(
128
)));
typedef
T
d256_t
__attribute__
((
ext_vector_type
(
256
)));
using
type
=
d256_t
;
union
{
d256_t
d256_
;
StaticallyIndexedArray
<
d1_t
,
256
>
d1x256_
;
StaticallyIndexedArray
<
d2_t
,
128
>
d2x128_
;
StaticallyIndexedArray
<
d4_t
,
64
>
d4x64_
;
StaticallyIndexedArray
<
d8_t
,
32
>
d8x32_
;
StaticallyIndexedArray
<
d16_t
,
16
>
d16x16_
;
StaticallyIndexedArray
<
d32_t
,
8
>
d32x8_
;
StaticallyIndexedArray
<
d64_t
,
4
>
d64x4_
;
StaticallyIndexedArray
<
d128_t
,
2
>
d128x2_
;
StaticallyIndexedArray
<
d256_t
,
1
>
d256x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
||
is_same
<
X
,
d256_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x256_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x128_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
...
...
@@ -1008,61 +1595,351 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
||
is_same
<
X
,
d256_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x256_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x128_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d128_t
>::
value
)
{
return
data_
.
d128x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d256_t
>::
value
)
{
return
data_
.
d256x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
,
index_t
N
,
typename
Enable
=
void
>
struct
non_native_vector_base
;
// Maps a non-native element type T to the raw storage type used by
// non_native_vector_base. Generic case: an unsigned bit-precise integer with
// exactly as many bits as T occupies. Explicit specializations override this
// for types that publish their own storage type.
template <typename T>
struct nnvb_data_t_selector
{
    using type = unsigned _BitInt(8 * sizeof(T));
};
// Storage-type overrides: each of these element types exposes a nested
// storage alias, which is forwarded unchanged as the selector result.

// OCP 8-bit float types publish their storage as `data_type`.
template <>
struct nnvb_data_t_selector<f8_ocp_t>
{
    using type = f8_ocp_t::data_type;
};

template <>
struct nnvb_data_t_selector<bf8_ocp_t>
{
    using type = bf8_ocp_t::data_type;
};

// Packed 6-bit float types publish their storage as `type`.
template <>
struct nnvb_data_t_selector<f6x16_pk_t>
{
    using type = f6x16_pk_t::type;
};

template <>
struct nnvb_data_t_selector<f6x32_pk_t>
{
    using type = f6x32_pk_t::type;
};

template <>
struct nnvb_data_t_selector<bf6x16_pk_t>
{
    using type = bf6x16_pk_t::type;
};

template <>
struct nnvb_data_t_selector<bf6x32_pk_t>
{
    using type = bf6x32_pk_t::type;
};

// Packed int4 publishes its storage as `type`.
template <>
struct nnvb_data_t_selector<pk_i4_t>
{
    using type = pk_i4_t::type;
};
// Vector-of-N wrapper for non-native element types whose size is 1, 2, 4 or
// 8 bytes. The same bytes are exposed through a union as:
//   dN   - one ext_vector of N raw storage elements,
//   dxN  - N individual raw storage elements (data_t),
//   dTxN - N individual elements viewed as T,
//   dNx1 - a one-element array holding the whole ext_vector.
template <typename T, index_t N>
struct non_native_vector_base<
    T,
    N,
    ck::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>
{
    // Raw storage type chosen by nnvb_data_t_selector; must be the same size as T.
    using data_t = typename nnvb_data_t_selector<T>::type;
    static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");

    using data_v = data_t __attribute__((ext_vector_type(N)));
    using type   = non_native_vector_base<T, N>;

    union alignas(next_pow2(N * sizeof(T)))
    {
        data_v dN; // storage vector (first member: the one brace-initialized by the ctors)
        StaticallyIndexedArray<data_t, N> dxN;
        StaticallyIndexedArray<T, N> dTxN;
        StaticallyIndexedArray<data_v, 1> dNx1;
    } data_;

    // Construct from one raw storage value, converted to data_v.
    __host__ __device__ constexpr non_native_vector_base(data_t raw) : data_{data_v(raw)} {}

    // Construct from one element of type T by reinterpreting its bits as data_t.
    __host__ __device__ constexpr non_native_vector_base(T elem)
        : non_native_vector_base(bit_cast<data_t>(elem))
    {
    }

    // Default construction delegates to the T constructor with a value-initialized T.
    __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}) {}

    // Construct directly from a full storage vector.
    __host__ __device__ constexpr non_native_vector_base(data_v vec) : data_{vec} {}

    // Whole-vector view.
    __host__ __device__ constexpr operator data_v() const { return data_.dN; }

    // Scalar raw-storage view; only meaningful for N == 1.
    __host__ __device__ constexpr operator data_t() const
    {
        if constexpr(N != 1)
        {
            return data_.dxN; // XXX this should cause an error
        }
        else
        {
            return data_.dxN[Number<0>{}];
        }
    }

    // Scalar element view; only meaningful for N == 1.
    __host__ __device__ constexpr operator T() const
    {
        if constexpr(N != 1)
        {
            return data_.dTxN; // XXX this should cause an error
        }
        else
        {
            return data_.dTxN[Number<0>{}];
        }
    }

    // Read-only access to the union member matching X
    // (data_t -> dxN, T -> dTxN, data_v -> dNx1).
    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
                      "Something went wrong, please check src and dst types.");

        if constexpr(is_same_v<X, data_t>)
        {
            return data_.dxN;
        }
        else if constexpr(is_same_v<X, T>)
        {
            return data_.dTxN;
        }
        else if constexpr(is_same_v<X, data_v>)
        {
            return data_.dNx1;
        }
        else
        {
            return err;
        }
    }

    // Mutable counterpart of the accessor above.
    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
                      "Something went wrong, please check src and dst types.");

        if constexpr(is_same_v<X, data_t>)
        {
            return data_.dxN;
        }
        else if constexpr(is_same_v<X, T>)
        {
            return data_.dTxN;
        }
        else if constexpr(is_same_v<X, data_v>)
        {
            return data_.dNx1;
        }
        else
        {
            return err;
        }
    }
};
// implementation for f6x16 and f6x32
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
<
T
,
N
,
std
::
enable_if_t
<
sizeof
(
T
)
==
12
||
sizeof
(
T
)
==
24
>>
{
using
data_t
=
typename
nnvb_data_t_selector
<
T
>::
type
;
// select data_t based on declared base type
using
element_t
=
typename
T
::
element_type
;
// select element_t based on declared element type
static_assert
(
sizeof
(
T
)
==
sizeof
(
data_t
),
"non_native_vector_base storage size mismatch"
);
static
constexpr
size_t
size_factor
=
sizeof
(
data_t
)
/
sizeof
(
element_t
);
// f6x16: 12/4 = 3, f6x32: 24/4 = 6
using
data_v
=
element_t
__attribute__
((
ext_vector_type
(
N
*
size_factor
)));
using
type
=
non_native_vector_base
<
T
,
N
>
;
union
alignas
(
next_pow2
(
N
*
sizeof
(
T
)))
{
data_v
dN
;
// storage vector;
StaticallyIndexedArray
<
data_t
,
N
>
dxN
;
StaticallyIndexedArray
<
T
,
N
>
dTxN
;
StaticallyIndexedArray
<
data_v
,
1
>
dNx1
;
}
data_
;
__host__
__device__
constexpr
non_native_vector_base
(
data_t
a
)
:
data_
{
data_v
(
a
.
At
(
Number
<
0
>
{}))}
{
}
__host__
__device__
constexpr
non_native_vector_base
(
T
f
)
:
non_native_vector_base
(
bit_cast
<
data_t
>
(
f
))
{
}
__host__
__device__
constexpr
non_native_vector_base
()
:
non_native_vector_base
(
T
{}){};
__host__
__device__
constexpr
non_native_vector_base
(
data_v
v
)
:
data_
{
v
}
{}
__host__
__device__
constexpr
operator
data_v
()
const
{
return
data_
.
dN
;
}
__host__
__device__
constexpr
operator
data_t
()
const
{
if
constexpr
(
N
==
1
)
{
return
data_
.
dxN
[
Number
<
0
>
{}];
}
else
{
return
data_
.
dxN
;
// XXX this should cause an error
}
}
__host__
__device__
constexpr
operator
T
()
const
{
if
constexpr
(
N
==
1
)
{
return
data_
.
dTxN
[
Number
<
0
>
{}];
}
else
{
return
data_
.
dTxN
;
// XXX this should cause an error
}
}
};
template
<
typename
T
,
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
T
,
N
>>
;
// scalar_type traits for a non-native vector of f8_ocp_t: the scalar type is
// the vector's raw storage element type, and the vector length is N.
template <index_t N>
struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
{
    using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;

    static constexpr index_t vector_size = N;
};
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
template
<
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
bf8_ocp_t
,
N
>>
{
using
type
=
non_native_vector_base
<
T
,
N
>
;
using
type
=
typename
non_native_vector_base
<
bf8_ocp_t
,
N
>::
data_t
;
__host__
__device__
non_native_vector_base
()
=
default
;
__host__
__device__
non_native_vector_base
(
const
type
&
)
=
default
;
__host__
__device__
non_native_vector_base
(
type
&&
)
=
default
;
__host__
__device__
~
non_native_vector_base
()
=
default
;
static
constexpr
index_t
vector_size
=
N
;
};
template
<
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
pk_i4_t
,
N
>>
{
using
type
=
typename
non_native_vector_base
<
pk_i4_t
,
N
>::
data_t
;
T
d
[
N
]
;
static
constexpr
index_t
vector_size
=
N
;
};
// non-native vector_type implementation
template
<
typename
T
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
1
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
type
=
d1_t
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
using
type
=
d1_nnv_t
;
union
alignas
(
next_pow2
(
1
*
sizeof
(
T
)))
{
d1_t
d1_
;
StaticallyIndexedArray
<
d1_t
,
1
>
d1x1_
;
d1_nnv_t
d1_nnv_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
()
:
data_
{
d1_t
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
return
data_
.
d1x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
return
data_
.
d1x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
2
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
type
=
d2_t
;
...
...
@@ -1081,10 +1958,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
return
data_
.
d1x2_
;
}
...
...
@@ -1101,10 +1979,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
return
data_
.
d1x2_
;
}
...
...
@@ -1120,9 +1999,10 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
4
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
...
...
@@ -1143,10 +2023,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
return
data_
.
d1x4_
;
}
...
...
@@ -1167,10 +2048,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
return
data_
.
d1x4_
;
}
...
...
@@ -1190,9 +2072,10 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
8
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
...
...
@@ -1215,11 +2098,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
return
data_
.
d1x8_
;
}
...
...
@@ -1244,11 +2128,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
return
data_
.
d1x8_
;
}
...
...
@@ -1272,9 +2157,10 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
16
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
...
...
@@ -1299,12 +2185,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d
2
_t
>::
value
||
is_same
<
X
,
d
4
_t
>::
value
||
is_same
<
X
,
d
8
_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d
1_nnv
_t
>::
value
||
is_same
<
X
,
d
2
_t
>::
value
||
is_same
<
X
,
d
4
_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
return
data_
.
d1x16_
;
}
...
...
@@ -1333,12 +2219,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d
2
_t
>::
value
||
is_same
<
X
,
d
4
_t
>::
value
||
is_same
<
X
,
d
8
_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d
1_nnv
_t
>::
value
||
is_same
<
X
,
d
2
_t
>::
value
||
is_same
<
X
,
d
4
_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d1_nnv_t
>::
value
)
{
return
data_
.
d1x16_
;
}
...
...
@@ -1366,7 +2252,7 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
32
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
32
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
...
...
@@ -1470,7 +2356,7 @@ struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
64
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
64
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
...
...
@@ -1541,134 +2427,415 @@ struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"Something went wrong, please check src and dst types."
);
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x1_
;
}
else
{
return
err
;
}
}
};
using
int64_t
=
long
;
// fp64
using
double2_t
=
typename
vector_type
<
double
,
2
>::
type
;
using
double4_t
=
typename
vector_type
<
double
,
4
>::
type
;
// fp32
using
float2_t
=
typename
vector_type
<
float
,
2
>::
type
;
using
float4_t
=
typename
vector_type
<
float
,
4
>::
type
;
using
float8_t
=
typename
vector_type
<
float
,
8
>::
type
;
using
float16_t
=
typename
vector_type
<
float
,
16
>::
type
;
using
float32_t
=
typename
vector_type
<
float
,
32
>::
type
;
using
float64_t
=
typename
vector_type
<
float
,
64
>::
type
;
// fp16
using
half2_t
=
typename
vector_type
<
half_t
,
2
>::
type
;
using
half4_t
=
typename
vector_type
<
half_t
,
4
>::
type
;
using
half8_t
=
typename
vector_type
<
half_t
,
8
>::
type
;
using
half16_t
=
typename
vector_type
<
half_t
,
16
>::
type
;
using
half32_t
=
typename
vector_type
<
half_t
,
32
>::
type
;
using
half64_t
=
typename
vector_type
<
half_t
,
64
>::
type
;
// bf16 (bfloat16)
using
bhalf2_t
=
typename
vector_type
<
bhalf_t
,
2
>::
type
;
using
bhalf4_t
=
typename
vector_type
<
bhalf_t
,
4
>::
type
;
using
bhalf8_t
=
typename
vector_type
<
bhalf_t
,
8
>::
type
;
using
bhalf16_t
=
typename
vector_type
<
bhalf_t
,
16
>::
type
;
using
bhalf32_t
=
typename
vector_type
<
bhalf_t
,
32
>::
type
;
using
bhalf64_t
=
typename
vector_type
<
bhalf_t
,
64
>::
type
;
// i32
using
int32x2_t
=
typename
vector_type
<
int32_t
,
2
>::
type
;
using
int32x4_t
=
typename
vector_type
<
int32_t
,
4
>::
type
;
using
int32x8_t
=
typename
vector_type
<
int32_t
,
8
>::
type
;
using
int32x16_t
=
typename
vector_type
<
int32_t
,
16
>::
type
;
using
int32x32_t
=
typename
vector_type
<
int32_t
,
32
>::
type
;
using
int32x64_t
=
typename
vector_type
<
int32_t
,
64
>::
type
;
// i8
using
int8x2_t
=
typename
vector_type
<
int8_t
,
2
>::
type
;
using
int8x4_t
=
typename
vector_type
<
int8_t
,
4
>::
type
;
using
int8x8_t
=
typename
vector_type
<
int8_t
,
8
>::
type
;
using
int8x16_t
=
typename
vector_type
<
int8_t
,
16
>::
type
;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// f8 (FNUZ variant)
using
f8x2_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
2
>::
type
;
using
f8x4_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
4
>::
type
;
using
f8x8_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
8
>::
type
;
using
f8x16_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
16
>::
type
;
using
f8x32_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
32
>::
type
;
using
f8x64_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
64
>::
type
;
// bf8 (FNUZ variant)
using
bf8x2_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
2
>::
type
;
using
bf8x4_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
4
>::
type
;
using
bf8x8_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
8
>::
type
;
using
bf8x16_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
16
>::
type
;
using
bf8x32_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
32
>::
type
;
using
bf8x64_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
64
>::
type
;
// f8 (OCP variant)
using
f8x2_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
2
>::
type
;
using
f8x4_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
4
>::
type
;
using
f8x8_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
8
>::
type
;
using
f8x16_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
16
>::
type
;
using
f8x32_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
32
>::
type
;
using
f8x64_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
64
>::
type
;
// bf8 (OCP variant)
using
bf8x2_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
2
>::
type
;
using
bf8x4_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
4
>::
type
;
using
bf8x8_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
8
>::
type
;
using
bf8x16_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
16
>::
type
;
using
bf8x32_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
32
>::
type
;
using
bf8x64_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
64
>::
type
;
#if CK_FP8_TYPE_OCP
// f8
using
f8x2_t
=
f8x2_ocp_t
;
using
f8x4_t
=
f8x4_ocp_t
;
using
f8x8_t
=
f8x8_ocp_t
;
using
f8x16_t
=
f8x16_ocp_t
;
using
f8x32_t
=
f8x32_ocp_t
;
using
f8x64_t
=
f8x64_ocp_t
;
// bf8
using
bf8x2_t
=
bf8x2_ocp_t
;
using
bf8x4_t
=
bf8x4_ocp_t
;
using
bf8x8_t
=
bf8x8_ocp_t
;
using
bf8x16_t
=
bf8x16_ocp_t
;
using
bf8x32_t
=
bf8x32_ocp_t
;
using
bf8x64_t
=
bf8x64_ocp_t
;
#elif CK_FP8_TYPE_FNUZ
// f8
using
f8x2_t
=
f8x2_fnuz_t
;
using
f8x4_t
=
f8x4_fnuz_t
;
using
f8x8_t
=
f8x8_fnuz_t
;
using
f8x16_t
=
f8x16_fnuz_t
;
using
f8x32_t
=
f8x32_fnuz_t
;
using
f8x64_t
=
f8x64_fnuz_t
;
// bf8
using
bf8x2_t
=
bf8x2_fnuz_t
;
using
bf8x4_t
=
bf8x4_fnuz_t
;
using
bf8x8_t
=
bf8x8_fnuz_t
;
using
bf8x16_t
=
bf8x16_fnuz_t
;
using
bf8x32_t
=
bf8x32_fnuz_t
;
using
bf8x64_t
=
bf8x64_fnuz_t
;
#endif
// u8
using
uint8x2_t
=
typename
vector_type
<
uint8_t
,
2
>::
type
;
using
uint8x4_t
=
typename
vector_type
<
uint8_t
,
4
>::
type
;
using
uint8x8_t
=
typename
vector_type
<
uint8_t
,
8
>::
type
;
using
uint8x16_t
=
typename
vector_type
<
uint8_t
,
16
>::
type
;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
// f4
using
f4x2_t
=
typename
vector_type
<
f4x2_pk_t
,
1
>::
type
;
using
f4x4_t
=
typename
vector_type
<
f4x2_pk_t
,
2
>::
type
;
using
f4x8_t
=
typename
vector_type
<
f4x2_pk_t
,
4
>::
type
;
using
f4x16_t
=
typename
vector_type
<
f4x2_pk_t
,
8
>::
type
;
using
f4x32_t
=
typename
vector_type
<
f4x2_pk_t
,
16
>::
type
;
using
f4x64_t
=
typename
vector_type
<
f4x2_pk_t
,
32
>::
type
;
// f6
using
f6x16_t
=
typename
vector_type
<
f6x16_pk_t
,
1
>::
type
;
using
f6x32_t
=
typename
vector_type
<
f6x32_pk_t
,
1
>::
type
;
// bf6
using
bf6x16_t
=
typename
vector_type
<
bf6x16_pk_t
,
1
>::
type
;
using
bf6x32_t
=
typename
vector_type
<
bf6x32_pk_t
,
1
>::
type
;
// packed int4
using
pk_i4x2_t
=
typename
vector_type
<
pk_i4_t
,
2
>::
type
;
using
pk_i4x4_t
=
typename
vector_type
<
pk_i4_t
,
4
>::
type
;
using
pk_i4x8_t
=
typename
vector_type
<
pk_i4_t
,
8
>::
type
;
#ifdef CK_CODE_GEN_RTC
template
<
typename
T
>
struct
NumericLimits
;
// Limits for int32_t. Infinity() and QuietNaN() return 0.
template <>
struct NumericLimits<int32_t>
{
    // Written as (-2147483647 - 1) so that every literal fits in int.
    __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }

    // Min() and Lowest() are the same value.
    __host__ __device__ static constexpr int32_t Min() noexcept { return Lowest(); }

    __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }

    __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }

    __host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
};
// Limits for int16_t. Infinity() and QuietNaN() return 0.
template <>
struct NumericLimits<int16_t>
{
    __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }

    // Min() and Lowest() are the same value.
    __host__ __device__ static constexpr int16_t Min() noexcept { return Lowest(); }

    __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }

    __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }

    __host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
};
// Limits for int8_t. Infinity() and QuietNaN() return 0.
template <>
struct NumericLimits<int8_t>
{
    __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }

    // Min() and Lowest() are the same value.
    __host__ __device__ static constexpr int8_t Min() noexcept { return Lowest(); }

    __host__ __device__ static constexpr int8_t Max() noexcept { return 127; }

    __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }

    __host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
};
// Limits for uint32_t. Lowest(), Min(), Infinity() and QuietNaN() all return 0.
template <>
struct NumericLimits<uint32_t>
{
    __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }

    // Min() and Lowest() are the same value.
    __host__ __device__ static constexpr uint32_t Min() noexcept { return Lowest(); }

    __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }

    __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }

    __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
};
// Limits for uint16_t. Lowest(), Min(), Infinity() and QuietNaN() all return 0.
template <>
struct NumericLimits<uint16_t>
{
    __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }

    // Min() and Lowest() are the same value.
    __host__ __device__ static constexpr uint16_t Min() noexcept { return Lowest(); }

    __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }

    __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }

    __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
};
// Limits for float, stored as raw bit patterns and converted with bit_cast.
template <>
struct NumericLimits<float>
{
    static constexpr unsigned int binary_min    = 0x00800000;
    static constexpr unsigned int binary_max    = 0x7F7FFFFF;
    static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
    static constexpr unsigned int binary_qnan   = 0xFFC00001;
    // IEEE-754 binary32 +infinity: sign 0, exponent all ones, mantissa 0.
    // The previous literal 0x7F8000000 had nine hex digits, which does not
    // fit in a 32-bit unsigned int, so Infinity() did not return +inf.
    static constexpr unsigned int binary_inf = 0x7F800000;

    __host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }

    __host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }

    __host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }

    __host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }

    __host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
};
// Limits for half_t, stored as raw 16-bit patterns and converted with bit_cast.
template <>
struct NumericLimits<half_t>
{
    static constexpr unsigned short binary_min    = 0x0400;
    static constexpr unsigned short binary_max    = 0x7BFF;
    static constexpr unsigned short binary_lowest = 0xFBFF;
    static constexpr unsigned short binary_qnan   = 0x7FFF;

    __host__ __device__ static constexpr half_t Min()
    {
        return bit_cast<half_t>(binary_min);
    }

    __host__ __device__ static constexpr half_t Max()
    {
        return bit_cast<half_t>(binary_max);
    }

    __host__ __device__ static constexpr half_t Lowest()
    {
        return bit_cast<half_t>(binary_lowest);
    }

    __host__ __device__ static constexpr half_t QuietNaN()
    {
        return bit_cast<half_t>(binary_qnan);
    }
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// Limits for int4_t: the range is [-8, 7].
template <>
struct NumericLimits<int4_t>
{
    __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }

    __host__ __device__ static constexpr int4_t Max() { return int4_t(7); }

    // Lowest() and Min() are the same value.
    __host__ __device__ static constexpr int4_t Lowest() { return Min(); }
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
/// Raw-bit numeric limits for `f8_fnuz_t`.
///
/// A stray `if constexpr (is_same<X, ...>)` chain that had been interleaved into this
/// struct (it referenced `X`, `data_` and `err`, none of which exist here) was removed;
/// it is not valid at class scope and does not belong to this specialization.
template <>
struct NumericLimits<f8_fnuz_t>
{
    // negative zero nan mode with exp bias = 8
    static constexpr uint8_t binary_min    = 0x08; // 0b00001000
    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000
    // ieee mode with exp bias = 7
    // static constexpr uint8_t binary_min    = 0x08; // 0b00001000
    // static constexpr uint8_t binary_max    = 0x77; // 0b01110111
    // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0

    /// Value constructed from `binary_min`.
    __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }

    /// Value constructed from `binary_max`.
    __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }

    /// Value constructed from `binary_lowest`.
    __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }

    /// Value constructed from `binary_qnan`.
    __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
};
using
int64_t
=
long
;
/// Raw-bit numeric limits for `bf8_fnuz_t`.
template <>
struct NumericLimits<bf8_fnuz_t>
{
    // negative zero nan mode with exp bias = 16
    static constexpr uint8_t binary_min    = 0x04; // 0b00000100
    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000
    // ieee mode with exp bias = 15
    // static constexpr uint8_t binary_min    = 0x04; // 0b00000100
    // static constexpr uint8_t binary_max    = 0x7B; // 0b01111011
    // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0

    /// Value constructed from `binary_min`.
    __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }

    /// Value constructed from `binary_max`.
    __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }

    /// Value constructed from `binary_lowest`.
    __host__ __device__ static constexpr bf8_fnuz_t Lowest()
    {
        return bf8_fnuz_t(binary_lowest);
    }

    /// Value constructed from `binary_qnan`.
    __host__ __device__ static constexpr bf8_fnuz_t QuietNaN()
    {
        return bf8_fnuz_t(binary_qnan);
    }
};

// Vector aliases. These declarations were interleaved with the member functions of the
// struct above, which would have turned them into member typedefs; they are kept here
// at namespace scope instead.
// NOTE(review): confirm they are not also declared elsewhere in this header.

// fp64
using double2_t = typename vector_type<double, 2>::type;
using double4_t = typename vector_type<double, 4>::type;

// fp32
using float2_t  = typename vector_type<float, 2>::type;
using float4_t  = typename vector_type<float, 4>::type;
using float8_t  = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;

// fp16
using half2_t  = typename vector_type<half_t, 2>::type;
using half4_t  = typename vector_type<half_t, 4>::type;
using half8_t  = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;

// bfp16
using bhalf2_t  = typename vector_type<bhalf_t, 2>::type;
using bhalf4_t  = typename vector_type<bhalf_t, 4>::type;
using bhalf8_t  = typename vector_type<bhalf_t, 8>::type;
using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
using bhalf64_t = typename vector_type<bhalf_t, 64>::type;
// i32
using
int32x2_t
=
typename
vector_type
<
int32_t
,
2
>::
type
;
using
int32x4_t
=
typename
vector_type
<
int32_t
,
4
>::
type
;
using
int32x8_t
=
typename
vector_type
<
int32_t
,
8
>::
type
;
using
int32x16_t
=
typename
vector_type
<
int32_t
,
16
>::
type
;
using
int32x32_t
=
typename
vector_type
<
int32_t
,
32
>::
type
;
using
int32x64_t
=
typename
vector_type
<
int32_t
,
64
>::
type
;
/// Raw-bit numeric limits for `f8_ocp_t`; accessors reinterpret the byte via `bit_cast`.
template <>
struct NumericLimits<f8_ocp_t>
{
    static constexpr uint8_t binary_min    = 0x08; // 0b00001000 = 2^-6
    static constexpr uint8_t binary_max    = 0x7E; // 0b01111110 = 448
    static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
    static constexpr uint8_t binary_qnan   = 0x7F; // 0b01111111

    /// Value encoded by `binary_min`.
    __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast<f8_ocp_t>(binary_min); }

    /// Value encoded by `binary_max`.
    __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast<f8_ocp_t>(binary_max); }

    /// Value encoded by `binary_lowest`.
    __host__ __device__ static constexpr f8_ocp_t Lowest()
    {
        return bit_cast<f8_ocp_t>(binary_lowest);
    }

    /// Value encoded by `binary_qnan`.
    __host__ __device__ static constexpr f8_ocp_t QuietNaN()
    {
        return bit_cast<f8_ocp_t>(binary_qnan);
    }
};

// Vector aliases. These declarations were interleaved with the member functions of the
// struct above, which would have turned them into member typedefs; they are kept here
// at namespace scope instead.
// NOTE(review): confirm they are not also declared elsewhere in this header.

// i8
using int8x2_t  = typename vector_type<int8_t, 2>::type;
using int8x4_t  = typename vector_type<int8_t, 4>::type;
using int8x8_t  = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;

// f8
using f8x2_t  = typename vector_type<f8_t, 2>::type;
using f8x4_t  = typename vector_type<f8_t, 4>::type;
using f8x8_t  = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;

// bf8
using bf8x2_t  = typename vector_type<bf8_t, 2>::type;
using bf8x4_t  = typename vector_type<bf8_t, 4>::type;
using bf8x8_t  = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;

// u8
using uint8x2_t  = typename vector_type<uint8_t, 2>::type;
using uint8x4_t  = typename vector_type<uint8_t, 4>::type;
using uint8x8_t  = typename vector_type<uint8_t, 8>::type;
using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type;
/// Raw-bit numeric limits for `bf8_ocp_t`; every accessor reinterprets the
/// corresponding byte constant through `bit_cast`.
template <>
struct NumericLimits<bf8_ocp_t>
{
    static constexpr uint8_t binary_min{0x04};    // 0b00000100 = 2^-14
    static constexpr uint8_t binary_max{0x7B};    // 0b01111011 = 57344
    static constexpr uint8_t binary_lowest{0xFB}; // 0b11111011 = -57344
    static constexpr uint8_t binary_qnan{0x7D};   // 0b01111101

    __host__ __device__ static constexpr bf8_ocp_t Min()
    {
        return bit_cast<bf8_ocp_t>(binary_min);
    }

    __host__ __device__ static constexpr bf8_ocp_t Max()
    {
        return bit_cast<bf8_ocp_t>(binary_max);
    }

    __host__ __device__ static constexpr bf8_ocp_t Lowest()
    {
        return bit_cast<bf8_ocp_t>(binary_lowest);
    }

    __host__ __device__ static constexpr bf8_ocp_t QuietNaN()
    {
        return bit_cast<bf8_ocp_t>(binary_qnan);
    }
};
#else
/// Generic numeric limits: each accessor forwards to the matching
/// `std::numeric_limits<T>` member function.
template <typename T>
struct NumericLimits
{
    /// Forwards to std::numeric_limits<T>::min().
    __host__ __device__ static constexpr T Min()
    {
        using limits = std::numeric_limits<T>;
        return limits::min();
    }

    /// Forwards to std::numeric_limits<T>::max().
    __host__ __device__ static constexpr T Max()
    {
        using limits = std::numeric_limits<T>;
        return limits::max();
    }

    /// Forwards to std::numeric_limits<T>::lowest().
    __host__ __device__ static constexpr T Lowest()
    {
        using limits = std::numeric_limits<T>;
        return limits::lowest();
    }

    /// Forwards to std::numeric_limits<T>::quiet_NaN().
    __host__ __device__ static constexpr T QuietNaN()
    {
        using limits = std::numeric_limits<T>;
        return limits::quiet_NaN();
    }

    /// Forwards to std::numeric_limits<T>::infinity().
    __host__ __device__ static constexpr T Infinity()
    {
        using limits = std::numeric_limits<T>;
        return limits::infinity();
    }
};
...
...
@@ -1702,7 +2869,7 @@ struct NumericLimits<int4_t>
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
struct
NumericLimits
<
f8_t
>
struct
NumericLimits
<
f8_
fnuz_
t
>
{
// negative zero nan mode with exp bias = 8
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
...
...
@@ -1715,17 +2882,17 @@ struct NumericLimits<f8_t>
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
f8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_
fnuz_
t
Min
()
{
return
f8_
fnuz_
t
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
f8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_
fnuz_
t
Max
()
{
return
f8_
fnuz_
t
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
f8_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_
fnuz_
t
Lowest
()
{
return
f8_
fnuz_
t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
f8_
fnuz_
t
QuietNaN
()
{
return
f8_
fnuz_
t
(
binary_qnan
);
}
};
template
<
>
struct
NumericLimits
<
bf8_t
>
struct
NumericLimits
<
bf8_
fnuz_
t
>
{
// negative zero nan mode with exp bias = 16
static
constexpr
uint8_t
binary_min
=
0x04
;
// 0b00000100
...
...
@@ -1738,13 +2905,172 @@ struct NumericLimits<bf8_t>
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__
__device__
static
constexpr
bf8_t
Min
()
{
return
bf8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
bf8_fnuz_t
Min
()
{
return
bf8_fnuz_t
(
binary_min
);
}
__host__
__device__
static
constexpr
bf8_fnuz_t
Max
()
{
return
bf8_fnuz_t
(
binary_max
);
}
__host__
__device__
static
constexpr
bf8_fnuz_t
Lowest
()
{
return
bf8_fnuz_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
bf8_t
Max
()
{
return
bf8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
bf8_fnuz_t
QuietNaN
()
{
return
bf8_fnuz_t
(
binary_qnan
);
}
};
/// Raw-bit numeric limits for `f8_ocp_t`; every accessor reinterprets the
/// corresponding byte constant through `bit_cast`.
template <>
struct NumericLimits<f8_ocp_t>
{
    static constexpr uint8_t binary_min{0x08};    // 0b00001000 = 2^-6
    static constexpr uint8_t binary_max{0x7E};    // 0b01111110 = 448
    static constexpr uint8_t binary_lowest{0xFE}; // 0b11111110 = -448
    static constexpr uint8_t binary_qnan{0x7F};   // 0b01111111

    __host__ __device__ static constexpr f8_ocp_t Min()
    {
        return bit_cast<f8_ocp_t>(binary_min);
    }

    __host__ __device__ static constexpr f8_ocp_t Max()
    {
        return bit_cast<f8_ocp_t>(binary_max);
    }

    __host__ __device__ static constexpr f8_ocp_t Lowest()
    {
        return bit_cast<f8_ocp_t>(binary_lowest);
    }

    __host__ __device__ static constexpr f8_ocp_t QuietNaN()
    {
        return bit_cast<f8_ocp_t>(binary_qnan);
    }
};
/// Raw-bit numeric limits for `bf8_ocp_t`; every accessor reinterprets the
/// corresponding byte constant through `bit_cast`.
template <>
struct NumericLimits<bf8_ocp_t>
{
    static constexpr uint8_t binary_min{0x04};    // 0b00000100 = 2^-14
    static constexpr uint8_t binary_max{0x7B};    // 0b01111011 = 57344
    static constexpr uint8_t binary_lowest{0xFB}; // 0b11111011 = -57344
    static constexpr uint8_t binary_qnan{0x7D};   // 0b01111101

    __host__ __device__ static constexpr bf8_ocp_t Min()
    {
        return bit_cast<bf8_ocp_t>(binary_min);
    }

    __host__ __device__ static constexpr bf8_ocp_t Max()
    {
        return bit_cast<bf8_ocp_t>(binary_max);
    }

    __host__ __device__ static constexpr bf8_ocp_t Lowest()
    {
        return bit_cast<bf8_ocp_t>(binary_lowest);
    }

    __host__ __device__ static constexpr bf8_ocp_t QuietNaN()
    {
        return bit_cast<bf8_ocp_t>(binary_qnan);
    }
};
#endif
/// Numeric limits for `f4_t`: 4-bit encodings of the normal/subnormal extremes plus
/// two of those extremes expressed as `float`.
template <>
struct NumericLimits<f4_t>
{
    static constexpr uint8_t binary_min_normal{0x2};    // 0b0010
    static constexpr uint8_t binary_max_normal{0x7};    // 0b0111
    static constexpr uint8_t binary_lowest_normal{0xF}; // 0b1111
    static constexpr uint8_t binary_min_subnorm{0x1};   // 0b0001
    static constexpr uint8_t binary_max_subnorm{0x1};   // 0b0001

    static constexpr float data_max_normal_number{6.0f};
    static constexpr float data_min_subnormal_number{0.5f};

    /// f4_t built from `binary_min_normal`.
    __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); }

    /// f4_t built from `binary_max_normal`.
    __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); }

    /// f4_t built from `binary_lowest_normal`.
    __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); }

    /// f4_t built from `binary_min_subnorm`.
    __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); }

    /// f4_t built from `binary_max_subnorm`.
    __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); }

    /// Returns `data_max_normal_number`.
    __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }

    /// Returns `data_min_subnormal_number`.
    __host__ __device__ static constexpr float DataMinSubnorm()
    {
        return data_min_subnormal_number;
    }
};
/// Numeric limits for `f6_t`: 6-bit encodings of the normal/subnormal extremes plus
/// two of those extremes expressed as `float`. Each encoding is masked to its low six
/// bits before the `f6_t` is constructed.
template <>
struct NumericLimits<f6_t>
{
    static constexpr uint8_t binary_min_normal{0x08};    // 0b001000
    static constexpr uint8_t binary_max_normal{0x1F};    // 0b011111
    static constexpr uint8_t binary_lowest_normal{0x3F}; // 0b111111
    static constexpr uint8_t binary_min_subnorm{0x01};   // 0b000001
    static constexpr uint8_t binary_max_subnorm{0x07};   // 0b000111

    static constexpr float data_max_normal_number{7.5f};
    static constexpr float data_min_subnormal_number{0.125f};

    // 0x3F == 0b111111 keeps only the six payload bits.
    __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0x3F); }

    __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0x3F); }

    __host__ __device__ static constexpr f6_t Lowest()
    {
        return f6_t(binary_lowest_normal & 0x3F);
    }

    __host__ __device__ static constexpr f6_t MinSubnorm()
    {
        return f6_t(binary_min_subnorm & 0x3F);
    }

    __host__ __device__ static constexpr f6_t MaxSubnorm()
    {
        return f6_t(binary_max_subnorm & 0x3F);
    }

    /// Returns `data_max_normal_number`.
    __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }

    /// Returns `data_min_subnormal_number`.
    __host__ __device__ static constexpr float DataMinSubnorm()
    {
        return data_min_subnormal_number;
    }
};
/// Numeric limits for `bf6_t`: 6-bit encodings of the normal/subnormal extremes plus
/// two of those extremes expressed as `float`.
template <>
struct NumericLimits<bf6_t>
{
    // Smallest normal: biased exponent 1, mantissa 0. With a 2-bit mantissa that is
    // 0b000100, the encoding immediately above binary_max_subnorm (0b000011). The
    // previous value 0x08 (0b001000) has biased exponent 2 and matches the f6_t
    // layout (3-bit mantissa) instead.
    static constexpr uint8_t binary_min_normal    = 0x04; // 0b000100
    static constexpr uint8_t binary_max_normal    = 0x1F; // 0b011111
    static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111
    static constexpr uint8_t binary_min_subnorm   = 0x01; // 0b000001
    static constexpr uint8_t binary_max_subnorm   = 0x03; // 0b000011

    static constexpr float data_max_normal_number    = 28;
    static constexpr float data_min_subnormal_number = 0.0625;

    /// bf6_t built from `binary_min_normal`.
    __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); }

    /// bf6_t built from `binary_max_normal`.
    __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); }

    /// bf6_t built from `binary_lowest_normal`.
    __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); }

    /// bf6_t built from `binary_min_subnorm`.
    __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); }

    /// bf6_t built from `binary_max_subnorm`.
    __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); }

    /// Returns `data_max_normal_number`.
    __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }

    /// Returns `data_min_subnormal_number`.
    __host__ __device__ static constexpr float DataMinSubnorm()
    {
        return data_min_subnormal_number;
    }
};
__host__
__device__
static
constexpr
bf8_t
Lowest
()
{
return
bf8_t
(
binary_lowest
);
}
/// Numeric limits and a few named exponent constants for `e8m0_bexp_t`.
///
/// A stray `bf8_t QuietNaN()` overload that had been interleaved into this struct was
/// removed: it returned a different type and would have clashed with the
/// `e8m0_bexp_t QuietNaN()` declared below (same name, same parameter list).
template <>
struct NumericLimits<e8m0_bexp_t>
{
    static constexpr e8m0_bexp_t binary_min  = 0x00; // 0b00000000
    static constexpr e8m0_bexp_t binary_max  = 0xFE; // 0b11111110
    static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111
    static constexpr e8m0_bexp_t binary_1    = 0x7F; // 0b01111111
    static constexpr e8m0_bexp_t binary_2    = 0x80; // 0b10000000
    static constexpr e8m0_bexp_t binary_3    = 0x82; // 0b10000010
    static constexpr e8m0_bexp_t binary_135  = 0x87; // 0b10000111
    static constexpr e8m0_bexp_t binary_142  = 0x8E; // 0b10001110

    /// Copy of `binary_min`.
    __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); }

    /// Copy of `binary_max`.
    __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); }

    /// Copy of `binary_qnan`.
    __host__ __device__ static constexpr e8m0_bexp_t QuietNaN()
    {
        return e8m0_bexp_t(binary_qnan);
    }

    /// Copy of `binary_1`.
    __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); }

    /// Copy of `binary_2`.
    __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); }

    /// Copy of `binary_3`.
    __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); }

    /// Copy of `binary_135`.
    __host__ __device__ static constexpr e8m0_bexp_t Binary_135()
    {
        return e8m0_bexp_t(binary_135);
    }

    /// Copy of `binary_142`.
    __host__ __device__ static constexpr e8m0_bexp_t Binary_142()
    {
        return e8m0_bexp_t(binary_142);
    }
};
template
<
typename
T
>
...
...
@@ -1766,6 +3092,7 @@ struct NumericUtils<float>
static
constexpr
uint32_t
NegInf
=
0xFF800000
;
static
constexpr
uint32_t
NaN
=
0x7F800001
;
static
constexpr
uint32_t
Neg0
=
0x80000000
;
static
constexpr
bool
has_inf
=
true
;
using
bitwise_type
=
uint32_t
;
};
...
...
@@ -1783,33 +3110,158 @@ struct NumericUtils<half_t>
static
constexpr
uint32_t
NegInf
=
0xFC00
;
static
constexpr
uint32_t
NaN
=
0x7C01
;
static
constexpr
uint32_t
Neg0
=
0x8000
;
static
constexpr
bool
has_inf
=
true
;
using
bitwise_type
=
uint16_t
;
};
/// Bit-layout constants for `bhalf_t`.
///
/// The header previously carried two struct heads back to back
/// (`NumericUtils<f8_t>` followed by `NumericUtils<bhalf_t>`); only the latter matches
/// the 8-bit exponent / 7-bit mantissa body and is kept.
template <>
struct NumericUtils<bhalf_t>
{
    static constexpr int exp  = 8;
    static constexpr int mant = 7;
    // NOTE(review): bfloat16 conventionally uses an exponent bias of 127; confirm that
    // 128 is intended here.
    static constexpr int bias = 128; // negative zero nan mode
    // static constexpr int bias = 127; // ieee mode
};
/// Bit-layout constants for `f8_fnuz_t`.
template <>
struct NumericUtils<f8_fnuz_t>
{
    static constexpr int exp{4};
    static constexpr int mant{3};
    static constexpr int bias{8}; // negative zero nan mode
    // static constexpr int bias = 7; // ieee mode
    static constexpr bool has_inf{false};
};
template
<
>
struct
NumericUtils
<
bf8_t
>
struct
NumericUtils
<
bf8_
fnuz_
t
>
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
16
;
// negative zero nan mode
// static constexpr int bias = 15; // ieee mode
static
constexpr
bool
has_inf
=
false
;
};
/// Bit-layout constants for `f8_ocp_t`.
template <>
struct NumericUtils<f8_ocp_t>
{
    static constexpr int exp{4};
    static constexpr int mant{3};
    static constexpr int bias{7};
};
/// Bit-layout constants for `bf8_ocp_t`.
///
/// The header previously carried two struct heads back to back
/// (`NumericUtils<bhalf_t>` followed by `NumericUtils<bf8_ocp_t>`); the stale first
/// head is dropped so the specialization is well-formed.
template <>
struct NumericUtils<bf8_ocp_t>
{
    static constexpr int exp  = 5;
    static constexpr int mant = 2;
    static constexpr int bias = 15;
};
/// Bit-layout constants, exponent ranges and bit masks for `f4_t`.
template <>
struct NumericUtils<f4_t>
{
    // Field widths and exponent bias.
    static constexpr int exp{2};
    static constexpr int mant{1};
    static constexpr int bias{1};
    static constexpr uint32_t sr_shift{10};

    // Exponent ranges.
    static constexpr int unbiased_exp_min{0};
    static constexpr int unbiased_exp_max{2};
    static constexpr int biased_exp_min{1};
    static constexpr int biased_exp_max{3};

    // 4-bit masks / encodings.
    static constexpr uint8_t positive_zero_mask{0x0}; // 0b0000
    static constexpr uint8_t negative_zero_mask{0x8}; // 0b1000
    static constexpr uint8_t one_mask{0x2};           // 0b0010
    static constexpr uint8_t set_sign_mask{0x7};      // 0b0111

    static constexpr uint8_t data_max_positive_normal_mask{0x7}; // 0b0111
    static constexpr uint8_t data_max_negative_normal_mask{0xF}; // 0b1111

    static constexpr uint8_t data_max_positive_subnormal_mask{0x1}; // 0b0001
    static constexpr uint8_t data_max_negative_subnormal_mask{0x9}; // 0b1001

    static constexpr bool has_inf{false};

    using bitwise_type = uint8_t;
};
/// Bit-layout constants, exponent ranges and bit masks for `f6_t`.
template <>
struct NumericUtils<f6_t>
{
    // Field widths and exponent bias.
    static constexpr int exp{2};
    static constexpr int mant{3};
    static constexpr int bias{1};
    static constexpr uint32_t sr_shift{12};

    // Exponent ranges.
    static constexpr int unbiased_exp_min{0};
    static constexpr int unbiased_exp_max{2};
    static constexpr int biased_exp_min{1};
    static constexpr int biased_exp_max{3};

    // 6-bit masks / encodings.
    static constexpr uint8_t positive_zero_mask{0x00}; // 0b000000
    static constexpr uint8_t negative_zero_mask{0x20}; // 0b100000
    static constexpr uint8_t set_sign_mask{0x1F};      // 0b011111

    static constexpr uint8_t data_max_positive_normal_mask{0x1F}; // 0b011111
    static constexpr uint8_t data_max_negative_normal_mask{0x3F}; // 0b111111

    static constexpr uint8_t data_max_positive_subnormal_mask{0x07}; // 0b000111
    static constexpr uint8_t data_max_negative_subnormal_mask{0x27}; // 0b100111

    static constexpr bool has_inf{false};
    static constexpr bool has_nan{false};
    static constexpr bool has_zero{true};

    using bitwise_type = uint8_t;
};
/// Bit-layout constants, exponent ranges and bit masks for `bf6_t`.
template <>
struct NumericUtils<bf6_t>
{
    // Field widths and exponent bias.
    static constexpr int exp{3};
    static constexpr int mant{2};
    static constexpr int bias{3};
    static constexpr uint32_t sr_shift{11};

    // Exponent ranges.
    static constexpr int unbiased_exp_min{-2};
    static constexpr int unbiased_exp_max{4};
    static constexpr int biased_exp_min{1};
    static constexpr int biased_exp_max{7};

    // 6-bit masks / encodings.
    static constexpr uint8_t positive_zero_mask{0x00}; // 0b000000
    static constexpr uint8_t negative_zero_mask{0x20}; // 0b100000
    static constexpr uint8_t set_sign_mask{0x1F};      // 0b011111

    static constexpr uint8_t data_max_positive_normal_mask{0x1F}; // 0b011111
    static constexpr uint8_t data_max_negative_normal_mask{0x3F}; // 0b111111

    static constexpr uint8_t data_max_positive_subnormal_mask{0x03}; // 0b000011
    static constexpr uint8_t data_max_negative_subnormal_mask{0x23}; // 0b100011

    static constexpr bool has_inf{false};
    static constexpr bool has_nan{false};
    static constexpr bool has_zero{true};

    using bitwise_type = uint8_t;
};
/// Bit-layout constants and exponent ranges for `e8m0_bexp_t`.
///
/// `mant` and `bias` were each declared twice (7/128 left over from another
/// specialization, then 0/127). Only the second pair is kept: it agrees with
/// `e8m0_bexp_t::bias == 127` and with the exponent ranges below
/// (biased 0..254 <-> unbiased -127..127).
template <>
struct NumericUtils<e8m0_bexp_t>
{
    static constexpr int exp  = 8;
    static constexpr int mant = 0;
    static constexpr int bias = 127;

    static constexpr int unbiased_exp_min = -127;
    static constexpr int unbiased_exp_max = 127;
    static constexpr int biased_exp_min   = 0;
    static constexpr int biased_exp_max   = 254;

    using bitwise_type = uint8_t;
};
}
// namespace ck
include/ck/utility/debug.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#include "type.hpp"
namespace
ck
{
namespace
debug
{
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
f1e53807
...
...
@@ -29,6 +29,13 @@ struct DynamicBuffer
ElementSpaceSize
element_space_size_
;
T
invalid_element_value_
=
T
{
0
};
static
constexpr
index_t
PackedSize
=
[]()
{
if
constexpr
(
is_same_v
<
remove_cvref_t
<
T
>
,
pk_i4_t
>
)
return
2
;
else
return
1
;
}();
__host__
__device__
constexpr
DynamicBuffer
(
T
*
p_data
,
ElementSpaceSize
element_space_size
)
:
p_data_
{
p_data
},
element_space_size_
{
element_space_size
}
{
...
...
@@ -54,7 +61,8 @@ struct DynamicBuffer
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
||
!
is_native_type
<
X
>
(),
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_element
)
const
{
...
...
@@ -81,14 +89,18 @@ struct DynamicBuffer
return
amd_buffer_load_invalid_element_return_zero
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
}
else
{
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
,
invalid_element_value_
);
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
,
invalid_element_value_
);
}
}
else
...
...
@@ -190,12 +202,13 @@ struct DynamicBuffer
dst_buf
.
p_data_
,
dst_offset
,
is_valid_element
,
element_space_size_
);
element_space_size_
/
PackedSize
);
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
||
!
is_native_type
<
X
>
(),
bool
>::
type
=
false
>
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
...
...
@@ -224,7 +237,7 @@ struct DynamicBuffer
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_store
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
}
else
if
constexpr
(
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
&&
is_same
<
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
,
int8_t
>::
value
&&
...
...
@@ -376,7 +389,7 @@ struct DynamicBuffer
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_add
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
}
else
{
...
...
@@ -415,7 +428,7 @@ struct DynamicBuffer
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_max
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
}
else
if
(
is_valid_element
)
{
...
...
include/ck/utility/e8m0.hpp
0 → 100644
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/type.hpp"
namespace
ck
{
/**
* @brief Unsigned representation of a conventional biased Float32 exponent.
*
* bias = 127;
*
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
* E8M0_MIN = 0b00000000; => 2^-127
* E8M0_MAX = 0b11111110; => 2^127
* E8M0_NAN = 0b11111111; => NaN
*/
struct e8m0_bexp_t
{
    using type = uint8_t;

    /// Stored biased exponent byte.
    type data;

    constexpr static type bias     = 127;
    constexpr static type nan_mask = 0xFF;

    /// Zero-initializes the stored byte.
    __host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {}

    /// Stores `init` unchanged.
    __host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {}

    /// Stores the low eight bits of `init`.
    __host__ __device__ constexpr e8m0_bexp_t(int init)
        : data{static_cast<type>(init & nan_mask)}
    {
    }

    /// Extracts bits 23..30 (the exponent field) of the float's bit pattern.
    __host__ __device__ explicit constexpr e8m0_bexp_t(float scale)
        : data{static_cast<type>((bit_cast<uint32_t>(scale) >> 23) & nan_mask)}
    {
    }

    /// Builds a float bit pattern from the stored byte.
    __host__ __device__ explicit constexpr operator float() const
    {
        const bool needs_mantissa_bit = (data == 0) || (data == nan_mask);

        if(!needs_mantissa_bit)
        {
            // Ordinary case: the byte becomes the exponent field, mantissa stays zero.
            return bit_cast<float>(static_cast<uint32_t>(data) << 23);
        }

        // data == 0 or data == 0xFF: same exponent field, but with the top mantissa
        // bit (bit 22) set as well, giving 0x00400000 or 0x7FC00000 respectively.
        const uint32_t exp_and_top_mant = (static_cast<uint32_t>(data) << 1) | 1u;
        return bit_cast<float>(exp_and_top_mant << 22);
    }

    /// Equal only when both bytes match and the value is not the NaN encoding.
    __host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const
    {
        // strict IEEE compliance for NaN
        return !is_nan() && data == other.data;
    }

    /// True when the stored byte equals `nan_mask`.
    __host__ __device__ constexpr bool is_nan() const { return data == nan_mask; }
};
namespace utils {

/// Returns the exponent value of `x`; implemented per type via explicit specialization.
template <typename T>
__host__ __device__ inline int get_exponent_value(T x);

/// For `e8m0_bexp_t` the exponent value is the stored byte itself, widened to int.
template <>
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
    const int stored_exponent = x.data;
    return stored_exponent;
}

} // namespace utils
}
// namespace ck
include/ck/utility/enable_if.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {

#ifndef CK_CODE_GEN_RTC
// Regular builds: reuse the standard library trait.
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;

template <bool B, typename T = void>
using enable_if_t = typename enable_if<B, T>::type;
#else
// Runtime-compilation builds: self-contained replacement for the standard trait.
// The primary template has no `type` member; only the `true` specialization does.
template <bool B, typename T = void>
struct enable_if
{
};

template <typename T>
struct enable_if<true, T>
{
    using type = T;
};

template <bool B, typename T = void>
using enable_if_t = typename enable_if<B, T>::type;
#endif

} // namespace ck
include/ck/utility/env.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CODE_GEN_RTC
#pragma once
#include <cstdlib>
...
...
@@ -183,3 +184,4 @@ void UpdateEnvVar(EnvVar, const std::string_view& val)
}
}
// namespace ck
#endif
include/ck/utility/functional.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
{
if
constexpr
(
predicate
)
{
return
std
::
forward
<
X
>
(
x
);
return
ck
::
forward
<
X
>
(
x
);
}
else
{
return
std
::
forward
<
Y
>
(
y
);
return
ck
::
forward
<
Y
>
(
y
);
}
}
...
...
include/ck/utility/functional4.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
...
...
@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>>
template
<
typename
F
,
typename
X
>
__host__
__device__
constexpr
auto
operator
()(
F
&&
f
,
X
&&
x
)
const
{
return
std
::
forward
<
F
>
(
f
)(
std
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...);
return
ck
::
forward
<
F
>
(
f
)(
ck
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...);
}
};
...
...
@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
template
<
typename
F
,
typename
X
,
typename
Y
>
__host__
__device__
constexpr
auto
operator
()(
F
&&
f
,
X
&&
x
,
Y
&&
y
)
const
{
return
std
::
forward
<
F
>
(
f
)(
std
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...,
std
::
forward
<
Y
>
(
y
).
At
(
Number
<
Js
>
{})...);
return
ck
::
forward
<
F
>
(
f
)(
ck
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...,
ck
::
forward
<
Y
>
(
y
).
At
(
Number
<
Js
>
{})...);
}
};
...
...
@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x)
{
using
X_
=
remove_reference_t
<
X
>
;
return
detail
::
unpack_impl
<
typename
arithmetic_sequence_gen
<
0
,
X_
::
Size
(),
1
>::
type
>
{}(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
X
>
(
x
));
ck
::
forward
<
F
>
(
f
),
ck
::
forward
<
X
>
(
x
));
}
// TODO: properly implement unpack that takes any number of containers
...
...
@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
using
Y_
=
remove_reference_t
<
Y
>
;
return
detail
::
unpack2_impl
<
typename
arithmetic_sequence_gen
<
0
,
X_
::
Size
(),
1
>::
type
,
typename
arithmetic_sequence_gen
<
0
,
Y_
::
Size
(),
1
>::
type
>
{}(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
X
>
(
x
),
std
::
forward
<
Y
>
(
y
));
ck
::
forward
<
F
>
(
f
),
ck
::
forward
<
X
>
(
x
),
ck
::
forward
<
Y
>
(
y
));
}
}
// namespace ck
...
...
include/ck/utility/integral_constant.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
return
integral_constant
<
decltype
(
X
%
Y
),
X
%
Y
>
{};
}
template
<
bool
B
>
using
bool_constant
=
integral_constant
<
bool
,
B
>
;
using
true_type
=
bool_constant
<
true
>
;
using
false_type
=
bool_constant
<
false
>
;
}
// namespace ck
include/ck/utility/is_detected.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/integral_constant.hpp"
namespace
ck
{
namespace
detail
{
template
<
class
Default
,
class
AlwaysVoid
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
{
using
value_t
=
std
::
false_type
;
using
value_t
=
integral_constant
<
bool
,
false
>
;
using
type
=
Default
;
};
template
<
class
Default
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
<
Default
,
std
::
void_t
<
Op
<
Args
...
>>
,
Op
,
Args
...
>
struct
detector
<
Default
,
ck
::
void_t
<
Op
<
Args
...
>>
,
Op
,
Args
...
>
{
using
value_t
=
std
::
true_type
;
using
value_t
=
integral_constant
<
bool
,
true
>
;
using
type
=
Op
<
Args
...
>
;
};
}
// namespace detail
...
...
@@ -32,12 +34,12 @@ template <template <class...> class Op, class... Args>
using
is_detected
=
typename
detail
::
detector
<
nonesuch
,
void
,
Op
,
Args
...
>::
value_t
;
template
<
typename
T
>
using
is_pack2_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack2_invocable
);
using
is_pack2_invocable_t
=
decltype
(
ck
::
declval
<
T
&>
().
is_pack2_invocable
);
template
<
typename
T
>
using
is_pack4_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack4_invocable
);
using
is_pack4_invocable_t
=
decltype
(
ck
::
declval
<
T
&>
().
is_pack4_invocable
);
template
<
typename
T
>
using
is_pack8_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack8_invocable
);
using
is_pack8_invocable_t
=
decltype
(
ck
::
declval
<
T
&>
().
is_pack8_invocable
);
}
// namespace ck
include/ck/utility/loop_scheduler.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CODE_GEN_RTC
#include <ostream>
#endif
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace
ck
{
...
...
@@ -26,6 +28,7 @@ constexpr LoopScheduler make_default_loop_scheduler()
}
// namespace ck
#ifndef CK_CODE_GEN_RTC
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ck
::
LoopScheduler
&
s
)
{
switch
(
s
)
...
...
@@ -36,3 +39,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
}
return
os
;
}
#endif
include/ck/utility/magic_division.hpp
View file @
f1e53807
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -9,6 +9,10 @@
#include "type.hpp"
#include "tuple.hpp"
#ifdef CK_CODE_GEN_RTC
#define INT32_MAX 2147483647
#endif
namespace
ck
{
// magic number division
...
...
Prev
1
…
15
16
17
18
19
20
21
22
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment