gaoqiong / composable_kernel_ROCM / Commits / 3c5717df

Commit 3c5717df (unverified), authored Feb 10, 2025 by Illia Silin, committed by GitHub on Feb 10, 2025

    Merge branch 'develop' into gemm_elementwise_gemm

Parents: 171b9030, d9f1ead3
Changes: 454 files; showing 20 changed files with 1711 additions and 395 deletions (+1711, -395)
include/ck_tile/core/arch/utility.hpp                       +24   -0
include/ck_tile/core/config.hpp                             +24   -3
include/ck_tile/core/container/meta_data_buffer.hpp         +3    -3
include/ck_tile/core/container/tuple.hpp                    +1    -1
include/ck_tile/core/numeric/bfloat16.hpp                   +47   -1
include/ck_tile/core/numeric/float8.hpp                     +593  -340
include/ck_tile/core/numeric/half.hpp                       +6    -5
include/ck_tile/core/numeric/numeric.hpp                    +2    -1
include/ck_tile/core/numeric/pk_int4.hpp                    +140  -0
include/ck_tile/core/numeric/vector_type.hpp                +18   -1
include/ck_tile/core/tensor/buffer_view.hpp                 +80   -6
include/ck_tile/core/tensor/load_tile.hpp                   +45   -9
include/ck_tile/core/tensor/static_distributed_tensor.hpp   +27   -0
include/ck_tile/core/tensor/tensor_view.hpp                 +42   -0
include/ck_tile/core/tensor/tile_window.hpp                 +92   -5
include/ck_tile/core/tensor/tile_window_linear.hpp          +142  -17
include/ck_tile/core/tensor/tile_window_utils.hpp           +54   -0
include/ck_tile/core/tensor/transpose_tile.hpp              +202  -0
include/ck_tile/core/tensor/update_tile.hpp                 +53   -3
include/ck_tile/core/utility/static_counter.hpp             +116  -0
include/ck_tile/core/arch/utility.hpp

@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
#endif
}

template <typename T>
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
{
    static_assert(sizeof(T) == 4);
    // per-thread v_flag store into 2x sgpr
    uint32x2_t exec_flag;
    asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
                 : [s_exec_flag] "=s"(exec_flag)
                 : [v_flag] "v"(v_flag));
    return exec_flag;
}

template <typename X, typename Y>
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
{
    static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
    // per-thread cmp store into 2x sgpr
    uint32x2_t exec_flag;
    asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
                 : [s_exec_flag] "=s"(exec_flag)
                 : [v_x] "v"(x), [v_y] "v"(y));
    return exec_flag;
}

} // namespace ck_tile
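
A hypothetical usage sketch (not part of this commit): how a ck_tile device function might call the two helpers added above to materialize per-lane predicates as SGPR-pair masks. The function name and the `flags`, `limit`, and `lane_id` inputs are illustrative assumptions; only flag_to_exec and cmp_lt_to_exec come from the diff.

// Sketch: assumes ck_tile device headers are included and this runs in device code.
CK_TILE_DEVICE void example_masks(const uint32_t* flags, uint32_t limit, uint32_t lane_id)
{
    uint32_t v_flag = flags[lane_id];                           // per-lane flag (0 or 1)
    auto exec_a     = ck_tile::flag_to_exec(v_flag);            // bit set for lanes with flag >= 1
    auto exec_b     = ck_tile::cmp_lt_to_exec(lane_id, limit);  // bit set for lanes with lane_id < limit
    (void)exec_a;
    (void)exec_b; // typically consumed by follow-up inline asm that expects a 2x SGPR mask
}
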
include/ck_tile/core/config.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
    defined(__gfx942__) || defined(__gfx950__)
#define __gfx9__
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
...

@@ -64,6 +64,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4

#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
...

@@ -143,6 +144,10 @@
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#endif

#ifndef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
#define CK_TILE_USE_PK4_LAYOUT_SHUFFLE 1
#endif

// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
...

@@ -225,3 +230,19 @@
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
#define CK_TILE_WORKAROUND_SWDEV_383542 1
#endif

#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif

#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifdef CK_TILE_USE_OCP_FP8
#define CK_TILE_USE_OCP_FP8 1
#else
#define CK_TILE_USE_OCP_FP8 0
#endif
#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code
#define CK_TILE_USE_OCP_FP8 1
#else // for GPU code
#define CK_TILE_USE_OCP_FP8 0
#endif
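
A minimal sketch (not part of the commit) of how these new switches are meant to be consumed by a translation unit that includes config.hpp. The comments restate the selection rules encoded above; passing -D definitions on the compiler command line is the usual mechanism and is an assumption here, not something this commit adds.

// Sketch: checking which fp8 flavor and which bf16 default rounding a build ends up with.
#include "ck_tile/core/config.hpp"

#if CK_TILE_USE_OCP_FP8
// OCP E4M3/E5M2 encodings are in effect: device code for gfx950/gfx12, or a host
// build compiled with -DCK_TILE_USE_OCP_FP8.
#else
// FNUZ (NANOO) encodings are in effect.
#endif

#if CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT == CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM
// Selected by building with -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM,
// i.e. the new round-to-nearest-away asm path added in this commit.
#endif
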
include/ck_tile/core/container/meta_data_buffer.hpp

@@ -30,7 +30,7 @@ struct meta_data_buffer
    {
        constexpr index_t size = sizeof(T);

        auto tmp = ck_tile::bit_cast<array<std::byte, size>>(data);

        for(int i = 0; i < size; i++)
        {
...

@@ -66,7 +66,7 @@ struct meta_data_buffer
                pos++;
            }

            data = ck_tile::bit_cast<T>(tmp);
        }

        return data;
...

@@ -86,7 +86,7 @@ struct meta_data_buffer
            pos++;
        }

        auto data = ck_tile::bit_cast<T>(tmp);

        return data;
    }
...
include/ck_tile/core/container/tuple.hpp

@@ -546,7 +546,7 @@ CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
            using Idx = number<tuple<Ts...>::size() - i - 1>;
            return t.at(Idx{});
        },
        number<tuple<Ts...>::size()>{});
}

// Reduce tuple values in specific range using Function
...
include/ck_tile/core/numeric/bfloat16.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...

@@ -18,6 +18,7 @@ enum class bf16_rounding_mode
    truncate_with_nan,
    truncate,
    standard_asm,
    rta_asm, // round to nearest away
};

template <bf16_rounding_mode rounding =
...

@@ -180,6 +181,39 @@ uint16_t float_to_bf16_rtn_asm(float f)
    return uint16_t(u.int32);
}

// TODO: do we need this on host?
CK_TILE_HOST uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }

CK_TILE_DEVICE uint16_t float_to_bf16_rta_asm(float f)
{
    union
    {
        float fp32;
        struct
        {
            uint16_t lo;
            uint16_t hi;
        };
    } u = {f};

    const uint32_t low_nan = 0x7fff;
    const uint32_t hi_nan  = 0x7fff0000;
    using uint32x2_t       = uint32_t __attribute__((ext_vector_type(2)));
    uint32x2_t check_nan;

    asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
                 "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
                 "v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
                 : [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
                 : [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));

    // Note: in the above code snippet, we use the hi 16 bits
    return u.hi;
}

// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
...

@@ -213,6 +247,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
        return float_to_bf16_rtn_asm(f);
    else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
        return float_to_bf16_truc_nan_raw(f);
    else if constexpr(rounding == bf16_rounding_mode::rta_asm)
        return float_to_bf16_rta_asm(f);
    else
        return float_to_bf16_truc_raw(f);
}
...

@@ -340,6 +376,16 @@ struct numeric<bfloat16_t>
    }
};

template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<bfloat16_t>
{
    static constexpr int exp  = 8;
    static constexpr int mant = 7;
};

#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
#endif
...
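
A standalone host-side sketch (plain C++, not the ck_tile API) contrasting the two bit-level roundings involved here: truncation and the round-to-nearest-away that the new rta_asm path implements by adding 0x8000 (0x7fff plus 1 in the v_add3_u32 above) before keeping the high 16 bits. NaN handling (the v_cmp_u_f32 / v_cndmask_b32 part of the asm) is omitted for brevity.

#include <cstdint>
#include <cstring>

static uint16_t bf16_truncate(float f)
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return static_cast<uint16_t>(u >> 16); // drop the low 16 mantissa bits
}

static uint16_t bf16_round_to_nearest_away(float f)
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    u += 0x8000;                           // add half a ULP of the bf16 mantissa
    return static_cast<uint16_t>(u >> 16); // keep the high 16 bits, as the asm path does
}
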
include/ck_tile/core/numeric/float8.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...

@@ -14,6 +14,12 @@
#pragma once

#if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_TILE_FP8_CVT_DEVICE 1
#else
#define CK_TILE_FP8_CVT_DEVICE 0
#endif

namespace ck_tile {

// fp8 rounding modes
...

@@ -25,15 +31,26 @@ enum class fp8_rounding_mode
    stochastic
};

/**
 * \brief FP8 interpretation used in conversion algorithms
 */
enum class fp8_interpretation
{
    E4M3_OCP  = 0, // OCP FP8 E4M3
    E5M2_OCP  = 1, // OCP BF8 E5M2
    E4M3_FNUZ = 2, // FNUZ FP8 E4M3
    E5M2_FNUZ = 3, // FNUZ BF8 E5M2
};

/*
 * ______________FNUZ_________________ | ______________OCP________________
 *             e4m3               e5m2 |           e4m3               e5m2
 * bias :         8                 16 |              7                 15
 * inf  : 1.0000.000        1.00000.00 |            N/A         s.11111.00
 * Nan  : 1.0000.000        1.00000.00 |     s.1111.111  s.11111.{01, 10, 11}
 * zero : 0.0000.000        0.00000.00 |     s.0000.000         s.00000.00
 * Max(norm) : s.1111.111 (240)  s.11111.11 (57344) | s.1111.110 (448)  s.11110.11 (57344)
 * Max(snorm): s.0000.111        s.00000.11         | s.0000.111        s.00000.11
 *             0.0068359375      2.288818e-05       | 0.013671875       4.57763671875e-05
 * Min(norm) : s.0001.000        s.00001.00         | s.0001.000        s.00001.00
 *             2^-7(0.00078125)  2^-15(3.05176e-05) | 2^-6(0.015625)    2^-14(6.10352e-05)
...
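
A worked decode example (not from the diff) for the bias difference summarized in the table: the same E4M3 byte decodes to different values under OCP (bias 7) and FNUZ (bias 8). The helper below is a simplified standalone decoder that ignores the NaN/Inf special encodings listed above.

#include <cstdint>
#include <cmath>

static float decode_e4m3(uint8_t v, int bias /* 7 for OCP, 8 for FNUZ */)
{
    int sign   = v >> 7;
    int exp    = (v >> 3) & 0xF;
    int mant   = v & 0x7;
    float frac = (exp == 0) ? (mant / 8.0f)          // subnormal: no implicit 1
                            : (1.0f + mant / 8.0f);
    int e      = (exp == 0) ? (1 - bias) : (exp - bias);
    float mag  = std::ldexp(frac, e);
    return sign ? -mag : mag;
}
// decode_e4m3(0x38, 7) == 1.0f under OCP (bias 7); decode_e4m3(0x38, 8) == 0.5f under FNUZ (bias 8)
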
@@ -55,10 +72,10 @@ struct alignas(1) float8_e4m3_t
{
    static constexpr int exponent = 4;
    static constexpr int mantissa = 3;
#if CK_TILE_USE_OCP_FP8
    static constexpr int bias = 7; // OCP
#else
    static constexpr int bias = 8; // FNUZ
#endif
    using raw_type = uint8_t;
    raw_type data;
...

@@ -113,10 +130,10 @@ struct alignas(1) float8_e5m2_t
{
    static constexpr int exponent = 5;
    static constexpr int mantissa = 2;
#if CK_TILE_USE_OCP_FP8
    static constexpr int bias = 15; // OCP
#else
    static constexpr int bias = 16; // FNUZ
#endif
    using raw_type = uint8_t;
    raw_type data;
...
@@ -183,501 +200,727 @@ struct native_t<bf8_t>
};
#else
using fp8_t     = _BitInt(8);
using fp8_raw_t = uint8_t;
using bf8_t     = unsigned _BitInt(8);
using bf8_raw_t = uint8_t;
#endif

template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<fp8_t>
{
    using bitwise_type = fp8_raw_t;

    static constexpr int exp  = 4;
    static constexpr int mant = 3;
#if CK_TILE_USE_OCP_FP8
    static constexpr int bias                        = 7;
    static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_OCP;
#else
    static constexpr int bias                        = 8;
    static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ;
#endif
    static constexpr uint8_t abs_mask = 0x7F;
};

template <>
struct numeric_traits<bf8_t>
{
    using bitwise_type = bf8_raw_t;

    static constexpr int exp  = 5;
    static constexpr int mant = 2;
#if CK_TILE_USE_OCP_FP8
    static constexpr int bias                        = 15;
    static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_OCP;
#else
    static constexpr int bias                        = 16;
    static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ;
#endif
    static constexpr uint8_t abs_mask = 0x7F;
};

// below is sw fp8 conversion, not utilizing hw instruction
namespace impl {

template <typename SrcT, typename DstT, bool clip = true, bool stoch = false>
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
{
    static_assert(std::is_same<DstT, fp8_t>::value || std::is_same<DstT, bf8_t>::value,
                  "DstT type must be fp8 or bf8.");

    constexpr bool is_half  = std::is_same<SrcT, half_t>::value;
    constexpr bool is_float = std::is_same<SrcT, float>::value;
    static_assert(is_half || is_float, "Only half and float can be cast to f8");

    // fp8/bf8 type exponent/mantissa layout
    constexpr int DstT_exp  = numeric_traits<DstT>::exp;  // exponent width of the destination type
    constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type

    constexpr bool is_fnuz =
        (numeric_traits<DstT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
        (numeric_traits<DstT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);

    constexpr int SrcT_exp  = numeric_traits<SrcT>::exp;
    constexpr int SrcT_mant = numeric_traits<SrcT>::mant;

    using SrcT_bitwise       = typename numeric_traits<SrcT>::bitwise_type;
    SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);

    unsigned long long head, mantissa;
    int exponent, bias;
    unsigned int sign;
    unsigned long long fInf, abs_mask;

    head     = src_bitwise & numeric_traits<SrcT>::head_mask;
    mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
    exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
    sign     = head >> (SrcT_exp + SrcT_mant);
    bias     = numeric_traits<SrcT>::bias;
    fInf     = numeric_traits<SrcT>::Inf;
    abs_mask = numeric_traits<SrcT>::abs_mask;

    unsigned int signed_inf = 0;
    unsigned int nan        = 0;
    if constexpr(is_fnuz)
    {
        signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
        nan        = 0x80;
    }
    else
    {
        if constexpr(DstT_exp == 4)
        { // e4m3
            signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
        }
        else
        { // e5m2
            signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
        }
        nan = (sign << 7) + 0x7f;
    }

    // Max values
    unsigned long long ifmax = 0;
    if constexpr(is_float)
    {
        if constexpr(DstT_exp == 5)
        {
            ifmax = 0x47600000;
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x43700000;
            }
            else
            {
                ifmax = 0x43E00000;
            }
        }
    }
    else if constexpr(is_half)
    {
        if constexpr(DstT_exp == 5)
        {
            ifmax = 0x7B00;
        }
        else
        {
            if constexpr(is_fnuz)
            {
                ifmax = 0x5B80;
            }
            else
            {
                ifmax = 0x5F00;
            }
        }
    }

    // Deal with inf and NaNs
    if((src_bitwise & fInf) == fInf)
    {
        if constexpr(is_fnuz)
            return signed_inf;

        return mantissa != 0 ? nan : signed_inf;
    }

    if((src_bitwise & abs_mask) > ifmax)
    {
        return signed_inf;
    }

    if(src_bitwise == 0)
    {
        return 0;
    }

    // First need to check if it is normal or denorm as there is a difference of
    // implicit 1. Then need to adjust the exponent to align with the F8 exponent,
    // in the meanwhile, shift the mantissa. Then for stochastic rounding, add rng
    // to mantissa and truncate. And for RNE, no need to add rng. Then probably
    // need to check whether there is carry and adjust exponent and mantissa again.

    // For IEEE bias mode, the bias is 2^(k-1) - 1 where k is the width of exponent
    // bits
    const int f8_bias                  = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
    const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal

    // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
    // f8_exponent is the converted f8 exponent with bias encoding
    // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
    // the difference needs to be adjusted and mantissa shifted
    int act_exponent, f8_exponent, exponent_diff;

    if(exponent == 0)
    { // fp32/fp16 is in denormal.
        /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
        mostly concern fp16 here. In this case, f8 is usually in denormal. But there
        could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
        exponent bias 16. It means that there are some numbers in fp16 denormal but they
        are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
        where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
        (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
        act_exponent  = exponent - bias + 1;
        exponent_diff = f8_denormal_act_exponent -
                        act_exponent; // actual exponent is exponent-bias+1 as it is denormal
    }
    else
    { // fp32/fp16 is normal with implicit 1
        act_exponent = exponent - bias;
        if(act_exponent <= f8_denormal_act_exponent)
        {
            /* This is the case where fp32/fp16 is normal but it is in f8 denormal
            range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
            actual exponent is -7, it is actually larger due to the implicit 1,
            Therefore it needs to be adjust to -6 and mantissa shift right by 1.
            So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
            exponent_diff = f8_denormal_act_exponent - act_exponent;
        }
        else
        { // both fp32/fp16 and f8 are in normal range
            exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
                               // for this case, act_exponent could be larger. Just
                               // that it does not need shift mantissa
        }
        mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
    }

    bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
                    (1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
    /* This part is a bit tricky. The judgment of whether it is a tie needs to be
    done before we shift right as shift right could rip off some residual part and
    make something not midpoint look like midpoint. For example, the fp16 number
    0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
    by 4 bits, it would look like midpoint.
    */

    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1ull << SrcT_mant);
    // if there is no implicit 1, it means the f8 is denormal and need to adjust
    // to denorm exponent
    f8_exponent =
        (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);

    // Now we have the exponent and mantissa adjusted
    unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
    bool odd = mantissa & (1ull << (SrcT_mant - DstT_mant)); // if the least significant bit that
                                                             // is not truncated is 1
    mantissa +=
        (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;

    // Now we deal with overflow
    if(f8_exponent == 0)
    {
        if((1ull << SrcT_mant) & mantissa)
        {
            f8_exponent = 1; // denormal overflow to become normal, promote exponent
        }
    }
    else
    {
        if((1ull << (SrcT_mant + 1)) & mantissa)
        {
            mantissa >>= 1;
            f8_exponent++;
        }
    }

    mantissa >>= (SrcT_mant - DstT_mant);

    // above range: quantize to maximum possible float of the same sign
    const int max_exp = (1 << DstT_exp) - 1;
    if(f8_exponent > max_exp)
    {
        if constexpr(clip)
        {
            mantissa    = (1 << DstT_mant) - 1;
            f8_exponent = max_exp;
        }
        else
        {
            return signed_inf;
        }
    }

    if(f8_exponent == 0 && mantissa == 0)
        return is_fnuz ? 0 : (sign << 7);
    mantissa &= (1 << DstT_mant) - 1;
    return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
}
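
A standalone illustration (assumed plain C++, not part of the commit) of the "detect the tie before shifting" point made in the comment inside run_cast_to_f8: a residual just above the midpoint collapses onto the midpoint once an alignment shift discards its trailing bit, so the tie test must be evaluated on the unshifted mantissa.

#include <cstdint>
#include <cassert>

int main()
{
    uint32_t mantissa = 0b1011001; // sample mantissa; we intend to drop the low 4 bits
    uint32_t drop     = 4;

    // Before any shift: residual 0b1001 is strictly above the midpoint 0b1000.
    bool above_midpoint_before = (mantissa & ((1u << drop) - 1)) > (1u << (drop - 1));

    // After an early alignment shift by 1, the trailing 1 is lost and the remaining
    // residual 0b100 reads exactly as a tie for the now 3 dropped bits.
    uint32_t shifted          = mantissa >> 1;
    bool looks_like_tie_after = (shifted & ((1u << (drop - 1)) - 1)) == (1u << (drop - 2));

    assert(above_midpoint_before && looks_like_tie_after);
    return 0;
}
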
template <typename SrcT, typename DstT, bool clip = true>
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
{
    static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
                  "SrcT type must be fp8 or bf8.");

    constexpr int SrcT_exp  = numeric_traits<SrcT>::exp;
    constexpr int SrcT_mant = numeric_traits<SrcT>::mant;

    constexpr bool is_fnuz =
        (numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
        (numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);

    constexpr bool is_half  = std::is_same<DstT, half_t>::value;
    constexpr bool is_float = std::is_same<DstT, float>::value;
    static_assert(is_half || is_float, "DstT type must be half_t or float.");

    // destination type exponent/mantissa layout
    constexpr int DstT_exp  = numeric_traits<DstT>::exp;  // exponent width of the destination type
    constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type

    constexpr DstT fInf    = bit_cast<DstT>(numeric_traits<DstT>::Inf);
    constexpr DstT fNegInf = bit_cast<DstT>(numeric_traits<DstT>::NegInf);
    constexpr DstT fNaN    = bit_cast<DstT>(numeric_traits<DstT>::NaN);
    constexpr DstT fNeg0   = bit_cast<DstT>(numeric_traits<DstT>::Neg0);

    DstT fmax{0}, fmin{0};
    // Max number in e5m2 57344
    if constexpr(is_half)
    {
        fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x7B00));
        fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xFB00));
    }
    else if constexpr(is_float)
    {
        fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x47600000));
        fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xC7600000));
    }

    if(x == 0)
    {
        return 0;
    }

    unsigned long long sign     = x >> 7;
    unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
    int exponent                = (x & 0x7F) >> SrcT_mant;

    if constexpr(is_fnuz)
    {
        if(x == 0x80)
        {
            return fNaN;
        }
    }
    else
    {
        if(x == 0x80)
        {
            return fNeg0;
        }
        if constexpr(SrcT_exp == 4)
        { // e4m3
            if((x & 0x7F) == 0x7F)
            {
                return fNaN;
            }
        }
        else if((x & 0x7C) == 0x7C)
        { // e5m2
            if((x & 0x3) == 0)
            {
                if constexpr(clip)
                {
                    return sign ? fmin : fmax;
                }
                return sign ? fNegInf : fInf;
            }
            return fNaN;
        }
    }

    typename numeric_traits<DstT>::bitwise_type retval;
    if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
    {
        retval = x << 8;
        return bit_cast<DstT>(retval);
    }

    const int exp_low_cutoff =
        (1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0);

    // subnormal input
    if(exponent == 0)
    {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        int sh = 1 + clz(mantissa) - (32 - SrcT_mant);
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1ull << SrcT_mant) - 1);
    }
    exponent += exp_low_cutoff - 1;
    mantissa <<= DstT_mant - SrcT_mant;

    // subnormal output (occurs when DstT is half_t, we=5, is_fnuz=true)
    if(exponent <= 0)
    {
        mantissa |= 1 << DstT_mant;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }

    retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa;
    return bit_cast<DstT>(retval);
}

template <typename X, typename Y, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
{
    // check datatype
    constexpr bool is_half  = std::is_same<Y, half_t>::value;
    constexpr bool is_float = std::is_same<Y, float>::value;
    static_assert(is_half || is_float, "only half and float are supported.");

    return bit_cast<Y>(run_cast_to_f8<X, Y, clip, stoch>(x, rng));
}

#if CK_TILE_FP8_CVT_DEVICE
/**
 * @brief Cast float to fp8/bf8 using device conversion instructions
 */
template <fp8_interpretation interpret, bool saturate, bool stochastic_rounding = false>
CK_TILE_DEVICE uint8_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
{
    uint8_t i8data;
    union
    {
        float fval;
        unsigned int i32val;
        unsigned char i8val[4]; // NOTE: not endian independent
    } val;

    unsigned int ival = 0;
    val.fval          = v;

    if constexpr(saturate)
    {
        if constexpr(interpret == fp8_interpretation::E4M3_FNUZ)
        {
            if((val.i32val & 0x7F800000) != 0x7F800000)
            { /// propagate NAN/INF, no clipping
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
            }
        }
        else if constexpr(interpret == fp8_interpretation::E4M3_OCP)
        { // OCP type
            if((val.i32val & 0x7F800000) != 0x7F800000)
            { /// propagate NAN/INF, no clipping
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
            }
        }
        else
        {
            if((val.i32val & 0x7F800000) != 0x7F800000)
            { /// propagate NAN/INF, no clipping
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
            }
        }
    }

    if constexpr(stochastic_rounding)
    {
        ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
                       (interpret == fp8_interpretation::E4M3_OCP)
                   ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
                   : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
        val.i32val = ival;
        i8data     = val.i8val[0]; // little endian
    }
    else
    { // RNE CVT
        ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
                       (interpret == fp8_interpretation::E4M3_OCP)
                   ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
                   : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
        val.i32val = ival;
        i8data     = val.i8val[0];
    }
    return i8data;
}
#endif // CK_TILE_FP8_CVT_DEVICE

} // namespace impl
/**
 * @brief Converts a floating-point value to an 8-bit floating-point representation with stochastic
 * rounding.
 *
 * This function converts a floating-point value (float or half_t) to an 8-bit floating-point
 * representation of type fp8_t or bf8_t. The conversion process may
 * involve clipping and uses a pseudo-random number generator for the stochastic rounding.
 *
 * @tparam DstT The destination type (fp8_t or bf8_t).
 * @tparam SrcT The source type (float or half_t) to be converted.
 * @param x The floating-point value to be converted.
 * @return The 8-bit floating-point representation of the input value.
 */
template <typename SrcT, typename DstT>
CK_TILE_HOST_DEVICE typename numeric_traits<DstT>::bitwise_type float_to_fp8_sr_raw(SrcT x)
{
    constexpr bool clip = true;
    constexpr int seed  = 42;
    uint32_t rng        = prand_generator_t<SrcT, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if CK_TILE_FP8_CVT_DEVICE
    return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, true>(x, rng);
#else
    return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
        impl::cast_to_f8<SrcT, DstT, clip, true>(x, rng));
#endif
}

/**
 * @brief Converts a floating-point value to an 8-bit floating-point representation with rounding to
 * nearest even.
 *
 * This function converts a floating-point value (float or half_t) to an 8-bit floating-point
 * representation of type fp8_t or bf8_t. The conversion process may involve clipping.
 *
 * @tparam DstT The destination type (fp8_t or bf8_t).
 * @tparam SrcT The source type (float or half_t) to be converted.
 * @param x The floating-point value to be converted.
 * @return The 8-bit floating-point representation of the input value.
 */
template <typename SrcT, typename DstT>
CK_TILE_HOST_DEVICE typename numeric_traits<DstT>::bitwise_type float_to_fp8_rtn_raw(SrcT x)
{
    constexpr bool clip = true;
#if CK_TILE_FP8_CVT_DEVICE
    return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, false>(x, 0);
#else
    return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
        impl::cast_to_f8<SrcT, DstT, clip, false>(x, 0));
#endif
}
// clang-format off
template <fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
{
    if constexpr(rounding == fp8_rounding_mode::standard)
    {
        return float_to_fp8_rtn_raw<float, fp8_t>(x);
    }
    else if constexpr(rounding == fp8_rounding_mode::stochastic)
    {
        return float_to_fp8_sr_raw<float, fp8_t>(x);
    }
    else
    {
        return fp8_raw_t{0};
    }
}

template <fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
{
    if constexpr(rounding == fp8_rounding_mode::standard)
    {
        return float_to_fp8_rtn_raw<float, bf8_t>(x);
    }
    else if constexpr(rounding == fp8_rounding_mode::stochastic)
    {
        return float_to_fp8_sr_raw<float, bf8_t>(x);
    }
    else
    {
        return bf8_raw_t{0};
    }
}

CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
{
#if CK_TILE_FP8_CVT_DEVICE
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    fval            = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
    // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    return impl::run_cast_from_f8<fp8_t, float>(bit_cast<fp8_t>(x));
#endif
}

CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
{
#if CK_TILE_FP8_CVT_DEVICE
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    fval            = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
    // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    return impl::run_cast_from_f8<bf8_t, float>(bit_cast<bf8_t>(x));
#endif
}

template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant<rounding> = {})
{
    return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
}

template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant<rounding> = {})
{
    return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
}

CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
{
    return fp8_to_float_raw(bit_cast<fp8_raw_t>(x));
}

CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
{
    return bf8_to_float_raw(bit_cast<bf8_raw_t>(x));
}
// clang-format on
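
A hypothetical round-trip sketch using the wrappers above; the function name is made up, and which path runs (the hardware cvt builtins or the software impl::cast_to_f8 fallback) depends on how CK_TILE_FP8_CVT_DEVICE resolves for the translation unit.

// Sketch: quantize a float to fp8 with round-to-nearest-even, then dequantize it.
CK_TILE_HOST_DEVICE float fp8_quantize_dequantize(float x)
{
    ck_tile::fp8_t q =
        ck_tile::float_to_fp8(x, ck_tile::constant<ck_tile::fp8_rounding_mode::standard>{});
    return ck_tile::fp8_to_float(q);
}
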
template <class T>
struct numeric;

#if CK_TILE_USE_OCP_FP8
template <>
struct numeric<fp8_t>
{
    // minimum finite value, or minimum positive normal value
    CK_TILE_HOST_DEVICE static constexpr fp8_t min()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08)); // 0b00001000 = 2^-6
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xfe)); // 0b11111110 = -448
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t max()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7e)); // 0b01111110 = 448
    }

    // difference between 1.0 and next representable f8 value (1.125)
    // returns fp8_t(0.125)
    CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20)); // 0.125
    }

    // rounding error (0.0625)
    // half of epsilon
    CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x18)); // 0.0625
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7F)); // 0b01111111
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xFF)); // 0b11111111
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
    }

    CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
    }
};

template <>
struct numeric<bf8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr bf8_t min()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04)); // 0b00000100 = 2^-14
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xfb)); // 0b11111011 = -57344
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t max()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7b)); // 0b01111011 = 57344
    }

    // difference between 1.0 and next representable bf8 value (1.25)
    CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34)); // 0.25
    }

    // rounding error (0.125)
    // half of epsilon
    CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x30)); // 0.125
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7c)); // 0b01111100
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7F)); // 0b01111111
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xFF));
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
    }

    CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
    }
};
#else
template <>
struct numeric<fp8_t>
{
...

@@ -811,6 +1054,7 @@ struct numeric<bf8_t>
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
    }
};
#endif
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
...

@@ -818,19 +1062,26 @@ CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
#endif

// math
template <typename T>
CK_TILE_HOST_DEVICE T abs(const T& x)
{
    static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
                  "Only fp8_t and bf8_t are supported");
    return bit_cast<T>(static_cast<uint8_t>(bit_cast<uint8_t>(x) & numeric_traits<T>::abs_mask));
}

CK_TILE_HOST_DEVICE bool isnan(const fp8_t& x)
{
    uint8_t xx = bit_cast<fp8_raw_t>(x);
#if CK_TILE_USE_OCP_FP8
    return (xx & 0x7f) == 0x7f;
#else
    return xx == 0x80;
#endif
}

#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_DEVICE fp8_t sqrt(fp8_t x)
{
    return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
...

@@ -842,20 +1093,21 @@ fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); }
CK_TILE_DEVICE fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
#endif

CK_TILE_HOST_DEVICE bool isnan(const bf8_t& x)
{
    uint8_t xx = bit_cast<bf8_raw_t>(x);
#if CK_TILE_USE_OCP_FP8
    return (xx & 0x7f) > 0x7c;
#else
    return xx == 0x80;
#endif
}

#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_DEVICE bf8_t sqrt(bf8_t x)
{
    return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
...

@@ -867,5 +1119,6 @@ bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); }
CK_TILE_DEVICE bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };
#endif

} // namespace ck_tile
include/ck_tile/core/numeric/half.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...

@@ -236,10 +236,11 @@ struct numeric_traits<half_t>
    static constexpr uint16_t head_mask = 0xFC00;
    static constexpr uint16_t mant_mask = 0x3FF;
    static constexpr uint16_t exp_mask  = 0x1F;
    static constexpr uint16_t abs_mask  = 0x7FFF;
    static constexpr uint16_t Inf       = 0x7C00;
    static constexpr uint16_t NegInf    = 0xFC00;
    static constexpr uint16_t NaN       = 0x7C01;
    static constexpr uint16_t Neg0      = 0x8000;
    using bitwise_type = uint16_t;
};
...
include/ck_tile/core/numeric/numeric.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...

@@ -89,6 +89,7 @@ struct numeric_traits<float>
    static constexpr uint32_t head_mask = 0xFF800000;
    static constexpr uint32_t mant_mask = 0x7FFFFF;
    static constexpr uint32_t exp_mask  = 0xFF;
    static constexpr uint32_t abs_mask  = 0x7FFFFFFF;
    static constexpr uint32_t Inf       = 0x7F800000;
    static constexpr uint32_t NegInf    = 0xFF800000;
    static constexpr uint32_t NaN       = 0x7F800001;
...
include/ck_tile/core/numeric/pk_int4.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
#include "ck_tile/core/numeric/int8.hpp"

#pragma once

namespace ck_tile {

// Packed 2x int4
struct pk_int4_t
{
    using type = int8_t;
    type data;
    __host__ __device__ constexpr pk_int4_t() : data{type{}} {}
    __host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
};

// limits
template <class T>
struct numeric;

template <>
struct numeric<pk_int4_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
    {
        constexpr uint8_t val = 0b10001000;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t lowest()
    {
        constexpr uint8_t val = 0b10001000;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
    {
        constexpr uint8_t val = 0b01110111;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t epsilon() { return 1; } // not used

    CK_TILE_HOST_DEVICE static constexpr pk_int4_t round_error() { return 1; } // not used

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t infinity() { return 1; } // not used

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t quiet_NaN() { return 1; } // not used

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t signaling_NaN() { return 1; } // not used

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t denorm_min() { return 1; } // not used

    CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
};

CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
    float x_l    = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h    = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    fp32x2_t res = {x_h, x_l};
#elif
    fp32x2_t res = {x_l, x_h};
#endif
    return res;
}

CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#elif
    uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
    const int EX  = 0x64006400;
    const int SUB = 0xE408E408; //-8

    int lo = i4s | EX;

    return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
}

CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
    float x_l    = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h    = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
#elif
    bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
#endif
    return res;
}

} // namespace ck_tile
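
A standalone host sketch (not the ck_tile API) mirroring the arithmetic of pk_int4_t_to_fp32x2_t above: each nibble is read as an unsigned 0..15 value and 8 is subtracted, and with CK_TILE_USE_PK4_LAYOUT_SHUFFLE enabled the high nibble is emitted first. The packed byte 0x9F is an arbitrary example input.

#include <cstdint>
#include <cstdio>

int main()
{
    uint8_t packed = 0x9F; // high nibble 0x9, low nibble 0xF
    float hi = static_cast<float>((packed & 0xf0) >> 4) - 8.0f; // 9 - 8 = 1.0
    float lo = static_cast<float>(packed & 0x0f) - 8.0f;        // 15 - 8 = 7.0
    std::printf("{%.1f, %.1f}\n", hi, lo); // shuffle layout orders the pair as {hi, lo}
    return 0;
}
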
include/ck_tile/core/numeric/vector_type.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...

@@ -200,4 +200,21 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
#endif

CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
    fp16x2_t vector_res;
    vector_res.x = x.x + y.x;
    vector_res.y = x.y + y.y;
    return vector_res;
}

CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
    fp16x2_t c;
    asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
    return c;
}

} // namespace ck_tile
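
A hypothetical device-side usage sketch for the pk_add_f16 helpers above; the wrapper name is made up. On device the v_pk_add_f16 overload adds both fp16 lanes in one packed instruction, while the host overload falls back to two scalar adds.

// Sketch: assumes ck_tile headers are included and this is compiled as device code.
CK_TILE_DEVICE ck_tile::fp16x2_t add_pairs(ck_tile::fp16x2_t a, ck_tile::fp16x2_t b)
{
    return ck_tile::pk_add_f16(a, b); // result is {a.x + b.x, a.y + b.y}
}
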
include/ck_tile/core/tensor/buffer_view.hpp
View file @
3c5717df
...
...
@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global,
    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-    CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
+    CK_TILE_DEVICE void update(index_t i,
+                               index_t linear_offset,
+                               bool is_valid_element,
+                               const X& x,
+                               bool_constant<oob_conditional_check> = {})
    {
        if constexpr(Op == memory_operation_enum::set)
        {
-            this->template set<X>(i, linear_offset, is_valid_element, x);
+            this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_add)
        {
-            this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
+            this->template atomic_add<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_max)
        {
-            this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
+            this->template atomic_max<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
-            auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
-            this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
+            auto tmp = this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
+            this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x + tmp);
            // tmp += x;
            // this->template set<X>(i, is_valid_element, tmp);
        }
    }
    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
              bool oob_conditional_check = true,
              bool pre_nop               = false,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void update_raw(index_t i,
                                   index_t linear_offset,
                                   bool is_valid_element,
                                   const X& x,
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {})
    {
        if constexpr(Op == memory_operation_enum::set)
        {
            this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_add)
        {
            this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
                i, linear_offset, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_max)
        {
            // this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
        }
    }
    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
...
...
@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global,
    }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
...
...
@@ -585,6 +626,39 @@ struct buffer_view<address_space_enum::global,
    }

    template <typename X,
              bool oob_conditional_check = true,
              bool pre_nop               = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void
    atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;

        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");
        static_assert(get_address_space() == address_space_enum::global, "only support global mem");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_buffer_atomic_add_raw<remove_cvref_t<T>,
                                  t_per_x,
                                  Coherence,
                                  oob_conditional_check,
                                  pre_nop>(
            x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
    }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
...
...
include/ck_tile/core/tensor/load_tile.hpp
View file @ 3c5717df
...
...
@@ -22,28 +22,32 @@ template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          index_t i_access           = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
                              number<i_access>                     = {},
                              bool_constant<oob_conditional_check> = {})
{
-    return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
+    return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          index_t i_access           = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
                                                       WindowLengths_,
                                                       TileDistribution_,
                                                       LinearBottomDims_>& tile_window,
                              number<i_access>                     = {},
                              bool_constant<oob_conditional_check> = {})
{
-    return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
+    return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
...
...
@@ -51,15 +55,35 @@ template <typename DistributedTensor_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          index_t i_access           = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
                              const tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
                              number<i_access>                     = {},
                              bool_constant<oob_conditional_check> = {})
{
-    return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
+    return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename DistributedTensor_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          index_t i_access           = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
                              const tile_window_linear<BottomTensorView_,
                                                       WindowLengths_,
                                                       TileDistribution_,
                                                       LinearBottomDims_>& tile_window,
                              number<i_access>                     = {},
                              bool_constant<oob_conditional_check> = {})
{
    return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
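The added `i_access` parameter lets a caller issue one of the window's accesses at a time instead of the whole tile, which is useful for manually interleaving memory traffic with compute. A hedged usage sketch follows; `dst` and `window` stand for a distributed tensor and a tile_window_linear created elsewhere, and the window is assumed to expose at least two accesses.

// sketch: issue the accesses of a linear tile window separately so that
// independent work can be scheduled between them by the caller
load_tile(dst, window, number<0>{}); // first access only
// ... caller-scheduled independent instructions ...
load_tile(dst, window, number<1>{}); // second access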
/**
...
...
@@ -76,6 +100,7 @@ template <typename T,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
...
...
@@ -83,11 +108,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                                                             WindowLengths_,
                                                                             TileDistribution_,
                                                                             NumCoord>& tile_window,
                                  number<i_access>                     = {},
                                  bool_constant<oob_conditional_check> = {},
                                  bool_constant<pre_nop>               = {})
{
    tile_window.load_raw(
-        tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+        tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename T,
...
...
@@ -95,6 +121,7 @@ template <typename T,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
...
...
@@ -102,11 +129,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                                           WindowLengths_,
                                                           TileDistribution_,
                                                           LinearBottomDims_>& tile_window,
                                  number<i_access>                     = {},
                                  bool_constant<oob_conditional_check> = {},
                                  bool_constant<pre_nop>               = {})
{
    tile_window.load_raw(
-        tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+        tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
...
...
@@ -114,6 +142,7 @@ template <typename LdsTileWindow_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto
...
...
@@ -122,11 +151,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                                                WindowLengths_,
                                                                TileDistribution_,
                                                                NumCoord>& tile_window,
                    number<i_access>                     = {},
                    bool_constant<oob_conditional_check> = {},
                    bool_constant<pre_nop>               = {})
{
-    return tile_window.async_load_raw(
-        lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+    return tile_window.async_load_raw(
+        lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
...
...
@@ -134,6 +166,7 @@ template <typename LdsTileWindow_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
...
...
@@ -141,11 +174,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                                                 WindowLengths_,
                                                                 TileDistribution_,
                                                                 LinearBottomDims_>& tile_window,
                                        number<i_access>                     = {},
                                        bool_constant<oob_conditional_check> = {},
                                        bool_constant<pre_nop>               = {})
{
-    return tile_window.async_load_raw(
-        lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+    return tile_window.async_load_raw(
+        lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
...
...
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @ 3c5717df
...
...
@@ -29,6 +29,7 @@ struct static_distributed_tensor
        remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;

    static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
    static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
    {
...
...
@@ -201,4 +202,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
    return unpacks;
}

namespace detail {

// check whether two static_distributed_tensor types have the same data type and number of
// elements, differing only in distribution
template <typename X, typename Y>
struct is_similiar_distributed_tensor
{
    static constexpr bool value = false;
};

template <typename TypeX, typename DistX, typename TypeY, typename DistY>
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
                                      static_distributed_tensor<TypeY, DistY>>
{
    using Tx = static_distributed_tensor<TypeX, DistX>;
    using Ty = static_distributed_tensor<TypeY, DistY>;

    static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
                                  Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
};

template <typename X, typename Y>
inline constexpr bool is_similiar_distributed_tensor_v = is_similiar_distributed_tensor<X, Y>::value;

} // namespace detail
} // namespace ck_tile
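A trait of this shape is typically used to guard moving one tile's per-thread payload into a tile with a different distribution. A hedged sketch of such a guard follows; the helper `copy_thread_buffer` is hypothetical and only illustrates the intended check, assuming both tiles expose `get_thread_buffer()` with identical buffer types.

// sketch: only allow a per-thread copy between tiles that carry the same
// data type and element count, even if their distributions differ
template <typename DstTile, typename SrcTile>
CK_TILE_DEVICE void copy_thread_buffer(DstTile& dst, const SrcTile& src)
{
    static_assert(ck_tile::detail::is_similiar_distributed_tensor_v<DstTile, SrcTile>,
                  "tiles must carry the same data type and element count");
    dst.get_thread_buffer() = src.get_thread_buffer();
}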
include/ck_tile/core/tensor/tensor_view.hpp
View file @ 3c5717df
...
...
@@ -333,6 +333,48 @@ struct tensor_view
            coord.get_offset(), linear_offset, is_valid_element, x);
    }

    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              bool pre_nop               = false,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    update_vectorized_elements_raw(const TensorCoord& coord,
                                   index_t linear_offset,
                                   const X& x,
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {})
    {
        buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
            coord.get_offset(),
            linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

    template <typename X,
              bool oob_conditional_check = true,
              bool pre_nop               = false,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    update_vectorized_elements_raw(const TensorCoord& coord,
                                   index_t linear_offset,
                                   bool is_valid_element,
                                   const X& x,
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {})
    {
        buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
            coord.get_offset(), linear_offset, is_valid_element, x);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_view{");
...
...
include/ck_tile/core/tensor/tile_window.hpp
View file @ 3c5717df
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -18,8 +18,17 @@
namespace ck_tile {

-// Note: this tile window does not support single issue
-// you need to use the tile_window_linear structure for this purpose
+/**
+ * @brief This class provides a tile (windowed) view of, and access to, device memory.
+ *
+ * @note This tile window does not support single-issue access; use the tile_window_linear
+ *       structure for that purpose.
+ *
+ * @tparam BottomTensorView_ Class describing & holding device tensor memory.
+ * @tparam WindowLengths_ Spatial sizes of the windowed view on the tensor.
+ * @tparam StaticTileDistribution_ Thread distribution (mapping) into tile dimensions.
+ * @tparam NumCoord TBD
+ */
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
...
...
@@ -292,12 +301,15 @@ struct tile_window_with_static_distribution
    {
        constexpr auto tile_dstr = TileDstr{};
        auto dst_tensor          = make_static_distributed_tensor<DataType>(tile_dstr);
-        load(dst_tensor, bool_constant<oob_conditional_check>{});
+        load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
        return dst_tensor;
    }

-    template <typename DistributedTensor, bool oob_conditional_check = true>
+    template <typename DistributedTensor,
+              index_t i_access_unsupport_ = -1,
+              bool oob_conditional_check  = true>
    CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
                             number<i_access_unsupport_>          = {},
                             bool_constant<oob_conditional_check> = {}) const
    {
        using Traits = load_store_traits;
...
...
@@ -785,6 +797,73 @@ struct tile_window_with_static_distribution
        });
    }

    template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
    CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                                   number<i_access_unsupport_>          = {},
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {}) const
    {
        using Traits   = load_store_traits;
        using vector_t = typename Traits::vector_t;
        using SFC_Ys   = typename Traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

                // read from distributed tensor
                vector_t vec_value;

                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_tuple(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        number<NDimY>{});

                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                    vec_value.template get_as<DataType>()(j) =
                        dstr_tensor.get_thread_buffer().template at<d>();
                });

                // write into bottom tensor
                get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
                    bottom_tensor_thread_coord,
                    0,
                    vec_value,
                    bool_constant<oob_conditional_check>{},
                    bool_constant<pre_nop>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys = container_concat(
                        generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
                        idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }
// move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset]
// also move window-origin
...
...
@@ -939,6 +1018,14 @@ CK_TILE_DEVICE void move_tile_window(
    window.move(step);
}
/**
 * @brief This class provides a description of a tile (windowed) view of device memory.
*
* @note This class does not provide any functions to read or modify device memory.
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
*/
template <typename BottomTensorView_, typename WindowLengths_>
struct tile_window_with_static_lengths
{
...
...
include/ck_tile/core/tensor/tile_window_linear.hpp
View file @ 3c5717df
...
...
@@ -432,23 +432,38 @@ struct tile_window_linear
    CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
    {
        constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
-        // since this is a linear offset, we assume the bottom X tensor is always linear
-        constexpr index_t linear_offset = [&]() {
-            constexpr auto x_idx_ = linear_coord;
-            constexpr auto x_len_ = TileDstr{}.get_lengths();
-            static_assert(x_idx_.size() == x_len_.size());
-            constexpr index_t x_dims_ = x_idx_.size();
-            index_t cu_stride_ = 1;
-            index_t cu_offset_ = 0;
-            static_for<0, x_dims_, 1>{}([&](auto i_) {
-                auto r_i_ = number<x_dims_ - i_ - 1>{};
-                cu_offset_ += x_idx_[r_i_] * cu_stride_;
-                cu_stride_ *= x_len_[r_i_];
-            });
-            return cu_offset_;
-        }();
-        return linear_offset;
+        constexpr auto is_pure_linear_tensor =
+            reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{});
+
+        if constexpr(is_pure_linear_tensor)
+        {
+            // this case is usually an LDS window, where everything is known at compile time.
+            // we directly use the BottomTensorView transform to compute the offset, in case of padding
+            auto bottom_tensor_coord =
+                make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
+            return bottom_tensor_coord.get_offset();
+        }
+        else
+        {
+            // this case is usually a global window, where only the last dim can be linear.
+            // we hack here and use the original TileDstr to compute the linear offset,
+            // hoping that there is no extra padding between the other dims, which makes sense
+            // since that would introduce runtime lengths (so a linear offset could not be used)
+            constexpr index_t linear_offset = [&]() {
+                constexpr auto x_idx_ = linear_coord;
+                constexpr auto x_len_ = TileDstr{}.get_lengths();
+                static_assert(x_idx_.size() == x_len_.size());
+                constexpr index_t x_dims_ = x_idx_.size();
+                index_t cu_stride_ = 1;
+                index_t cu_offset_ = 0;
+                static_for<0, x_dims_, 1>{}([&](auto i_) {
+                    auto r_i_ = number<x_dims_ - i_ - 1>{};
+                    cu_offset_ += x_idx_[r_i_] * cu_stride_;
+                    cu_stride_ *= x_len_[r_i_];
+                });
+                return cu_offset_;
+            }();
+            return linear_offset;
+        }
    }

    CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
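The fallback lambda above computes a plain row-major offset by walking the X dimensions from fastest- to slowest-varying. A standalone sketch of the same cumulative-stride computation in ordinary host C++ (array names and sizes are illustrative only):

#include <array>
#include <cstddef>

// row-major offset: offset = sum_i idx[i] * stride[i], with strides accumulated from
// the last (fastest-varying) dimension, mirroring the lambda above
template <std::size_t N>
constexpr int linear_offset(const std::array<int, N>& idx, const std::array<int, N>& len)
{
    int stride = 1;
    int offset = 0;
    for(std::size_t i = 0; i < N; ++i)
    {
        const std::size_t r = N - i - 1; // walk dimensions from last to first
        offset += idx[r] * stride;
        stride *= len[r];
    }
    return offset;
}

// for lengths {4, 5, 6} the strides are {30, 6, 1}
static_assert(linear_offset<3>({1, 2, 3}, {4, 5, 6}) == 1 * 30 + 2 * 6 + 3);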
...
...
@@ -509,6 +524,64 @@ struct tile_window_linear
        return dst_tensor;
    }

    template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load(DstTile& dst_tensor,
                             number<i_access>                     = {},
                             bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};
        // auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

        auto issue = [&](auto i_access_) {
            constexpr auto IAccess       = number<i_access_>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};

            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);

            // read from bottom tensor
            const vector_t vec_value =
                get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord,
                    linear_offset,
                    bottom_tensor_flag,
                    bool_constant<oob_conditional_check>{});
#if 1
            // data index [y0, y1, ...]
            constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);

            // write into distributed tensor
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                dst_tensor.get_thread_buffer().template at<d>() =
                    vec_value.template get_as<DataType>()[j];
            });
#else
            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
            static_assert(d % traits::ScalarPerVector == 0);

            dst_tensor.get_thread_buffer().template get_as<vector_t>()(
                number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
        };

        WINDOW_DISPATCH_ISSUE();

        return dst_tensor;
    }

    template <typename DstTile,
              index_t i_access           = -1,
              bool oob_conditional_check = true,
...
...
@@ -849,6 +922,58 @@ struct tile_window_linear
        WINDOW_DISPATCH_ISSUE();
    }
    template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
    CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                                   number<i_access>                     = {},
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess       = number<i_access_>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};

            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;

            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                vec_value.template get_as<DataType>()(j) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                vec_value,
                bool_constant<oob_conditional_check>{},
                bool_constant<pre_nop>{});
        };

        WINDOW_DISPATCH_ISSUE();
    }
// move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset]
// also move window-origin
...
...
include/ck_tile/core/tensor/tile_window_utils.hpp
0 → 100644
View file @ 3c5717df
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

#pragma once

namespace ck_tile {

// input an LDS store tile and extract some information from it,
// used to set the m0 value for the gfx9 series
template <typename LdsTileWindow_>
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
{
    using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
    using LdsDataType   = typename LdsTileWindow::DataType;

    // issues * warps * lanes
    static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded

    const index_t size_per_buf =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<0>{}, number<0>{}, number<0>{})) *
        sizeof(LdsDataType);

    const index_t size_per_wave =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<0>{}, number<1>{}, number<0>{})) *
            sizeof(LdsDataType) -
        size_per_buf;

    const index_t size_per_issue =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<1>{}, number<0>{}, number<0>{})) *
            sizeof(LdsDataType) -
        size_per_buf;

    const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();

    return make_tuple(m0_init_value, size_per_issue);
}
} // namespace ck_tile
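On the gfx9 series, asynchronous buffer-to-LDS copies deposit data at the address held in the m0 register, so this helper derives a per-wave base offset and a per-issue stride from the 3-D (issues x warps x lanes) LDS descriptor. A standalone sketch of the same arithmetic with made-up sizes (illustrative only, not the kernel code):

#include <cstdint>

// layout: issues x warps x lanes elements in LDS, row-major, matching the
// descriptor queried above; sizes here are arbitrary examples
constexpr int warps          = 4;
constexpr int lanes          = 64;
constexpr int bytes_per_elem = 2; // e.g. fp16
constexpr int size_per_wave  = lanes * bytes_per_elem;          // offset(0,1,0) in bytes
constexpr int size_per_issue = warps * lanes * bytes_per_elem;  // offset(1,0,0) in bytes

// m0 starts at the wave's base offset and advances by size_per_issue per issue
constexpr int m0_for(int wave_id, int issue)
{
    return wave_id * size_per_wave + issue * size_per_issue;
}

static_assert(m0_for(2, 0) == 2 * size_per_wave);
static_assert(m0_for(0, 3) == 3 * size_per_issue);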
include/ck_tile/core/tensor/transpose_tile.hpp
0 → 100644
View file @ 3c5717df
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"

namespace ck_tile {

namespace detail {

template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
                                                    const InTensor& in_tensor)
{
    constexpr auto I0 = number<0>{};

    static_assert(std::is_same_v<typename InTensor::DataType, typename OutTensor::DataType>,
                  "Data type for InTensor and OutTensor must be the same!");
    using DataType = typename InTensor::DataType;

    constexpr auto y_in_desc  = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
    constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();

    // y_dim_out_to_in
    // For swapped Hs tile case I need only get_rh_minor_to_y
    // since rh_major are already swapped due to swapped Hs.
    constexpr auto get_rh_minor_to_y = [](auto dstr_tensor) {
        using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
        map<index_t, index_t> rh_minor_to_y_;
        static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
            constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
            rh_minor_to_y_(rh_minor)   = i;
        });
        return rh_minor_to_y_;
    };

    // In swapped Hs case <Y,X> -> <X,Y> tile
    // we have same rh_major, but reversed rh_minor!
    constexpr auto rh_minor_to_y_in  = get_rh_minor_to_y(InTensor{});
    constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{});

    // Is this really needed?? Should we have simple reverse here??
    constexpr auto y_dim_out_to_in = [&] {
        map<index_t, index_t> y_dim_out_to_in_;
        for(const auto& [rh_minor, y_out] : rh_minor_to_y_out)
        {
            y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor];
        }
        return y_dim_out_to_in_;
    }();

    constexpr index_t NDimY  = InTensor::get_tile_distribution().get_num_of_dimension_y();
    constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());

    // input and output vector dim in the order of input Y dims
    constexpr index_t y_dim_vec_in  = NDimY - 1;
    constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];

    // vector lengths
    constexpr index_t vec_length_in  = y_lengths[y_dim_vec_in];
    constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];

    // # of vectors
    constexpr index_t num_vec_in  = vec_length_out;
    constexpr index_t num_vec_out = vec_length_in;

    using InVec  = array<DataType, vec_length_in>;
    using OutVec = array<DataType, vec_length_out>;

    // SFC
    constexpr auto scalars_per_access_arr = generate_array(
        [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
        number<NDimY>{});

    constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);

    using SFC_Y = space_filling_curve<decltype(y_lengths),
                                      typename arithmetic_sequence_gen<0, NDimY, 1>::type,
                                      decltype(scalars_per_access)>;

    constexpr index_t num_access = SFC_Y::get_num_of_access();
    static_assert(num_access > 0, "wrong! num_access should be larger than 0");

    // in/out vectors to be transposed
    thread_buffer<InVec, num_vec_in>   in_vectors;
    thread_buffer<OutVec, num_vec_out> out_vectors;

    // loop over SFC and do transpose
    static_for<0, num_access, 1>{}([&](auto iAccess) {
        // data index [y0, y1, ...] in the order of input tensor
        constexpr auto idx_y_start = SFC_Y::get_index(iAccess);

        // get input vectors
        static_for<0, num_vec_in, 1>{}([&](auto i) {
            constexpr auto idx_y_in = generate_tuple(
                [&](auto ii) {
                    return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
                },
                number<NDimY>{});

            constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
            static_assert(in_offset % vec_length_in == 0);

            in_vectors(i).template get_as<InVec>()(I0) =
                in_tensor.get_thread_buffer()
                    .template get_as<InVec>()[number<in_offset / vec_length_in>{}];
        });

        // transpose
        transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);

        // set output vectors
        static_for<0, num_vec_out, 1>{}([&](auto i) {
            constexpr auto idx_y_out_tmp = generate_array(
                [&](auto ii) {
                    return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
                },
                number<NDimY>{});

            constexpr auto idx_y_out =
                container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);

            constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
            static_assert(out_offset % vec_length_out == 0);

            out_tensor.get_thread_buffer().template set_as<OutVec>(
                number<out_offset / vec_length_out>{},
                out_vectors[i].template get_as<OutVec>()[I0]);
        });
    });
}
} // namespace detail

template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void transpose_tile2d(OutTensor& out, const InTensor& in)
{
    using InDataType  = typename InTensor::DataType;
    using OutDataType = typename OutTensor::DataType;

    using InTileDistr  = typename InTensor::StaticTileDistribution;
    using OutTileDistr = typename OutTensor::StaticTileDistribution;

    using InDstrEncode  = typename InTileDistr::DstrEncode;
    using OutDstrEncode = typename OutTileDistr::DstrEncode;

    using InThreadTensorDesc  = typename InTensor::ThreadTensorDesc;
    using OutThreadTensorDesc = typename OutTensor::ThreadTensorDesc;

    // Ys:
    constexpr auto in_thread_desc_lengths  = InThreadTensorDesc{}.get_lengths();
    constexpr auto out_thread_desc_lengths = OutThreadTensorDesc{}.get_lengths();

    // type convert
    const auto in_tmp = [&]() {
        if constexpr(std::is_same_v<OutDataType, InDataType>)
        {
            return in;
        }
        else
        {
            return tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
        }
    }();

    // Scenario where we switch from tile <Y, X> -> <X, Y> - only 2D tiles!
    // we preserve Ps but swap Ys: <Y1, Y0> -> <Y0, Y1>
    if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
                 InDstrEncode::hs_lengthss_ == tuple_reverse(OutDstrEncode::hs_lengthss_) &&
                 InDstrEncode::NDimY == OutDstrEncode::NDimY && InDstrEncode::NDimY == 2 &&
                 in_thread_desc_lengths == tuple_reverse(out_thread_desc_lengths))
    // Any condition on Ps ??
    // InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
    // InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
    {
        detail::transpose_tile2d_impl_in_thread(out, in_tmp);
    }
    else
    {
        static_assert(false, "Provided tensors could not be transposed!");
    }
}
} // namespace ck_tile
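The core of the in-thread transpose is transpose_vectors, which regroups N input vectors of length M into M output vectors of length N, i.e. the in-register equivalent of a 2D transpose. A standalone sketch of that regrouping on plain arrays (illustrative only; this is not the ck_tile implementation):

#include <array>

// regroup `num_in` vectors of length `len_in` into `len_in` vectors of length `num_in`,
// transposing a num_in x len_in element block held in registers
template <typename T, int num_in, int len_in>
void transpose_block(const std::array<std::array<T, len_in>, num_in>& in,
                     std::array<std::array<T, num_in>, len_in>& out)
{
    for(int i = 0; i < num_in; ++i)
        for(int j = 0; j < len_in; ++j)
            out[j][i] = in[i][j]; // element (i, j) of the input lands at (j, i) of the output
}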
include/ck_tile/core/tensor/update_tile.hpp
View file @ 3c5717df
...
...
@@ -41,15 +41,65 @@ template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
-          typename DataType_>
+          typename DataType_,
+          index_t i_access           = -1,
+          bool oob_conditional_check = true>
CK_TILE_DEVICE void update_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                                     WindowLengths_,
                                                                     TileDistribution_,
                                                                     NumCoord>& tile_window,
-                                const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+                                const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+                                number<i_access>                     = {},
+                                bool_constant<oob_conditional_check> = {})
{
-    tile_window.update(dstr_tensor);
+    tile_window.update(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE void
update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
                                                     WindowLengths_,
                                                     TileDistribution_,
                                                     NumCoord>& tile_window,
                const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
                number<i_access>                     = {},
                bool_constant<oob_conditional_check> = {},
                bool_constant<pre_nop>               = {})
{
    tile_window.update_raw(dstr_tensor,
                           number<i_access>{},
                           bool_constant<oob_conditional_check>{},
                           bool_constant<pre_nop>{});
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          typename DataType_,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto
update_tile_raw(tile_window_linear<BottomTensorView_,
                                   WindowLengths_,
                                   TileDistribution_,
                                   LinearBottomDims_>& tile_window,
                const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
                number<i_access>                     = {},
                bool_constant<oob_conditional_check> = {},
                bool_constant<pre_nop>               = {})
{
    tile_window.update_raw(dstr_tensor,
                           number<i_access>{},
                           bool_constant<oob_conditional_check>{},
                           bool_constant<pre_nop>{});
}
} // namespace ck_tile
0 → 100644
View file @
3c5717df
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace
ck_tile
{
template
<
typename
Context
,
index_t
Start
=
0
,
index_t
Step
=
1
>
struct
static_counter
{
public:
template
<
typename
Unique
>
static
constexpr
index_t
next
()
{
return
next
<
Unique
>
(
0
)
*
Step
+
Start
;
}
template
<
unsigned
long
long
>
static
constexpr
index_t
next
()
{
struct
Unique
{
};
return
next
<
Unique
>
(
0
)
*
Step
+
Start
;
}
template
<
typename
Unique
>
static
constexpr
index_t
current
()
{
return
current
<
Unique
>
(
0
)
*
Step
+
Start
;
}
template
<
unsigned
long
long
>
static
constexpr
index_t
current
()
{
struct
Unique
{
};
return
current
<
Unique
>
(
0
)
*
Step
+
Start
;
}
private:
template
<
index_t
I
>
struct
slot
{
_Pragma
(
"GCC diagnostic push"
);
_Pragma
(
"GCC diagnostic ignored
\"
-Wundefined-internal
\"
"
);
friend
constexpr
bool
slot_allocated
(
slot
<
I
>
);
_Pragma
(
"GCC diagnostic pop"
);
};
template
<
index_t
I
>
struct
allocate_slot
{
friend
constexpr
bool
slot_allocated
(
slot
<
I
>
)
{
return
true
;
}
enum
{
value
=
I
};
};
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
// the overload set...
template
<
typename
Unique
,
index_t
I
=
0
,
bool
=
slot_allocated
(
slot
<
I
>())
>
static
constexpr
index_t
next
(
index_t
)
{
return
next
<
Unique
,
I
+
1
>
(
0
);
}
// ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
// allocate_slot<I>.
template
<
typename
Unique
,
index_t
I
=
0
>
static
constexpr
index_t
next
(
double
)
{
return
allocate_slot
<
I
>::
value
;
}
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
// the overload set...
template
<
typename
Unique
,
index_t
I
=
Start
,
bool
=
slot_allocated
(
slot
<
I
>())
>
static
constexpr
index_t
current
(
index_t
)
{
return
current
<
Unique
,
I
+
1
>
(
0
);
}
// ...And this function will be used, instead, which will return the current counter, or assert
// in case next() hasn't been called yet.
template
<
typename
Unique
,
index_t
I
=
Start
>
static
constexpr
index_t
current
(
double
)
{
static_assert
(
I
!=
0
,
"You must invoke next() first"
);
return
I
-
1
;
}
};
namespace
impl
{
template
<
int
I
>
struct
static_counter_uniq_
;
}
#define MAKE_SC() \
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
#define MAKE_SC_WITH(start_, step_) \
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> {}
#define NEXT_SC(c_) c_.next<__COUNTER__>()
#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>()
// Usage:
// constexpr auto c = MAKE_SC()
// NEXT_SC(c) // -> constexpr 0
// NEXT_SC(c) // -> constexpr 1
// NEXT_SC(c) // -> constexpr 2
}
// namespace ck_tile
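A hedged compile-time sketch of the counter in use, following the usage comment above (assuming this header is included and the translation unit is compiled as C++17; each NEXT_SC expansion consumes a fresh __COUNTER__ value, so every call yields the next constant):

// each invocation is a distinct constant expression, usable as a template argument
constexpr auto counter = MAKE_SC();
static_assert(NEXT_SC(counter) == 0);
static_assert(NEXT_SC(counter) == 1);
static_assert(NEXT_SC(counter) == 2);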