Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
06d2c7b1
Commit
06d2c7b1
authored
Jun 28, 2023
by
Jing Zhang
Committed by
root
Jun 28, 2023
Browse files
clean
parents
b27909a0
3b18f1e3
Changes
1000
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
455 additions
and
183 deletions
+455
-183
include/ck/utility/amd_llvm_intrinsic.hpp
include/ck/utility/amd_llvm_intrinsic.hpp
+0
-14
include/ck/utility/amd_wave_read_first_lane.hpp
include/ck/utility/amd_wave_read_first_lane.hpp
+138
-0
include/ck/utility/amd_wmma.hpp
include/ck/utility/amd_wmma.hpp
+1
-1
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+2
-2
include/ck/utility/array.hpp
include/ck/utility/array.hpp
+1
-1
include/ck/utility/array_multi_index.hpp
include/ck/utility/array_multi_index.hpp
+1
-1
include/ck/utility/c_style_pointer_cast.hpp
include/ck/utility/c_style_pointer_cast.hpp
+1
-1
include/ck/utility/common_header.hpp
include/ck/utility/common_header.hpp
+3
-1
include/ck/utility/container_element_picker.hpp
include/ck/utility/container_element_picker.hpp
+1
-1
include/ck/utility/container_helper.hpp
include/ck/utility/container_helper.hpp
+1
-1
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+32
-145
include/ck/utility/debug.hpp
include/ck/utility/debug.hpp
+1
-1
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+17
-8
include/ck/utility/enable_if.hpp
include/ck/utility/enable_if.hpp
+1
-1
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+250
-0
include/ck/utility/functional.hpp
include/ck/utility/functional.hpp
+1
-1
include/ck/utility/functional2.hpp
include/ck/utility/functional2.hpp
+1
-1
include/ck/utility/functional3.hpp
include/ck/utility/functional3.hpp
+1
-1
include/ck/utility/functional4.hpp
include/ck/utility/functional4.hpp
+1
-1
include/ck/utility/generic_memory_space_atomic.hpp
include/ck/utility/generic_memory_space_atomic.hpp
+1
-1
No files found.
Too many changes to show.
To preserve performance only
1000 of 1000+
files are displayed.
Plain diff
Email patch
include/ck/utility/amd_llvm_intrinsic.hpp
deleted
100644 → 0
View file @
b27909a0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_LLVM_INTRINSIC_HPP
#define CK_AMD_LLVM_INTRINSIC_HPP
#include "data_type.hpp"
namespace
ck
{
__device__
int32_t
llvm_amdgcn_readfirstlane_i32
(
int32_t
i
)
__asm
(
"llvm.amdgcn.readfirstlane"
);
}
// namespace ck
#endif
include/ck/utility/amd_wave_read_first_lane.hpp
0 → 100644
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp"
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
namespace
ck
{
namespace
detail
{
template
<
unsigned
SizeInBytes
>
struct
get_carrier
;
template
<
>
struct
get_carrier
<
1
>
{
using
type
=
uint8_t
;
};
template
<
>
struct
get_carrier
<
2
>
{
using
type
=
uint16_t
;
};
template
<
>
struct
get_carrier
<
3
>
{
using
type
=
class
carrier
{
using
value_type
=
uint32_t
;
std
::
array
<
std
::
byte
,
3
>
bytes
;
static_assert
(
sizeof
(
bytes
)
<=
sizeof
(
value_type
));
// replacement of host std::copy_n()
template
<
typename
InputIterator
,
typename
Size
,
typename
OutputIterator
>
__device__
static
OutputIterator
copy_n
(
InputIterator
from
,
Size
size
,
OutputIterator
to
)
{
if
(
0
<
size
)
{
*
to
=
*
from
;
++
to
;
for
(
Size
count
=
1
;
count
<
size
;
++
count
)
{
*
to
=
*++
from
;
++
to
;
}
}
return
to
;
}
// method to trigger template substitution failure
__device__
carrier
(
const
carrier
&
other
)
noexcept
{
copy_n
(
other
.
bytes
.
begin
(),
bytes
.
size
(),
bytes
.
begin
());
}
public:
__device__
carrier
&
operator
=
(
value_type
value
)
noexcept
{
copy_n
(
reinterpret_cast
<
const
std
::
byte
*>
(
&
value
),
bytes
.
size
(),
bytes
.
begin
());
return
*
this
;
}
__device__
operator
value_type
()
const
noexcept
{
std
::
byte
result
[
sizeof
(
value_type
)];
copy_n
(
bytes
.
begin
(),
bytes
.
size
(),
result
);
return
*
reinterpret_cast
<
const
value_type
*>
(
result
);
}
};
};
static_assert
(
sizeof
(
get_carrier
<
3
>::
type
)
==
3
);
template
<
>
struct
get_carrier
<
4
>
{
using
type
=
uint32_t
;
};
template
<
unsigned
SizeInBytes
>
using
get_carrier_t
=
typename
get_carrier
<
SizeInBytes
>::
type
;
}
// namespace detail
__device__
inline
int32_t
amd_wave_read_first_lane
(
int32_t
value
)
{
return
__builtin_amdgcn_readfirstlane
(
value
);
}
template
<
typename
Object
,
typename
=
std
::
enable_if_t
<
std
::
is_class_v
<
Object
>
&&
std
::
is_trivially_copyable_v
<
Object
>>>
__device__
auto
amd_wave_read_first_lane
(
const
Object
&
obj
)
{
using
Size
=
unsigned
;
constexpr
Size
SgprSize
=
4
;
constexpr
Size
ObjectSize
=
sizeof
(
Object
);
auto
*
const
from_obj
=
reinterpret_cast
<
const
std
::
byte
*>
(
&
obj
);
alignas
(
Object
)
std
::
byte
to_obj
[
ObjectSize
];
constexpr
Size
RemainedSize
=
ObjectSize
%
SgprSize
;
constexpr
Size
CompleteSgprCopyBoundary
=
ObjectSize
-
RemainedSize
;
for
(
Size
offset
=
0
;
offset
<
CompleteSgprCopyBoundary
;
offset
+=
SgprSize
)
{
using
Sgpr
=
detail
::
get_carrier_t
<
SgprSize
>
;
*
reinterpret_cast
<
Sgpr
*>
(
to_obj
+
offset
)
=
amd_wave_read_first_lane
(
*
reinterpret_cast
<
const
Sgpr
*>
(
from_obj
+
offset
));
}
if
constexpr
(
0
<
RemainedSize
)
{
using
Carrier
=
detail
::
get_carrier_t
<
RemainedSize
>
;
*
reinterpret_cast
<
Carrier
*>
(
to_obj
+
CompleteSgprCopyBoundary
)
=
amd_wave_read_first_lane
(
*
reinterpret_cast
<
const
Carrier
*>
(
from_obj
+
CompleteSgprCopyBoundary
));
}
/// NOTE: Implicitly start object lifetime. It's better to use std::start_lifetime_at() in this
/// scenario
return
*
reinterpret_cast
<
Object
*>
(
to_obj
);
}
}
// namespace ck
include/ck/utility/amd_wmma.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP
...
...
include/ck/utility/amd_xdlops.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
...
...
@@ -344,7 +344,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
double
&
reg_a
,
const
double
&
reg_b
,
FloatC
&
reg_c
)
{
#if defined(__gfx90a__) || defined(__gfx940__)
#if defined(__gfx90a__) || defined(__gfx940__)
|| defined(__gfx941__) || defined(__gfx942__)
reg_c
.
template
AsType
<
double4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f64_16x16x4f64
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
double4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
#else
...
...
include/ck/utility/array.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP
...
...
include/ck/utility/array_multi_index.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_ARRAY_MULTI_INDEX_HPP
#define CK_ARRAY_MULTI_INDEX_HPP
...
...
include/ck/utility/c_style_pointer_cast.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_C_STYLE_POINTER_CAST_HPP
#define CK_C_STYLE_POINTER_CAST_HPP
...
...
include/ck/utility/common_header.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -24,6 +24,7 @@
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/magic_division.hpp"
#include "ck/utility/c_style_pointer_cast.hpp"
#include "ck/utility/is_known_at_compile_time.hpp"
...
...
@@ -33,6 +34,7 @@
#include "ck/utility/debug.hpp"
#include "ck/utility/amd_buffer_addressing.hpp"
#include "ck/utility/amd_wave_read_first_lane.hpp"
#include "ck/utility/generic_memory_space_atomic.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/utility/thread_group.hpp"
...
...
include/ck/utility/container_element_picker.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CONTAINER_ELEMENT_PICKER_HPP
#define CK_CONTAINER_ELEMENT_PICKER_HPP
...
...
include/ck/utility/container_helper.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CONTAINER_HELPER_HPP
#define CK_CONTAINER_HELPER_HPP
...
...
include/ck/utility/data_type.hpp
View file @
06d2c7b1
...
...
@@ -12,6 +12,7 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
#endif
using
f8_t
=
uint8_t
;
// vector_type
template
<
typename
T
,
index_t
N
>
...
...
@@ -142,6 +143,13 @@ struct scalar_type<int4_t>
};
#endif
template
<
>
struct
scalar_type
<
f8_t
>
{
using
type
=
f8_t
;
static
constexpr
index_t
vector_size
=
1
;
};
//
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
...
...
@@ -944,151 +952,13 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// Convert X to Y
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
return
static_cast
<
Y
>
(
x
);
}
// convert bfp16 to fp32
template
<
>
inline
__host__
__device__
constexpr
float
type_convert
<
float
,
bhalf_t
>
(
bhalf_t
x
)
{
union
{
uint32_t
int32
;
float
fp32
;
}
u
=
{
uint32_t
(
x
)
<<
16
};
return
u
.
fp32
;
}
// convert fp32 to bfp16
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
float
>
(
float
x
)
{
union
{
float
fp32
;
uint32_t
int32
;
}
u
=
{
x
};
return
uint16_t
(
u
.
int32
>>
16
);
}
// convert bfp16 to fp16 via fp32
template
<
>
inline
__host__
__device__
constexpr
half_t
type_convert
<
half_t
,
bhalf_t
>
(
bhalf_t
x
)
{
float
x_fp32
=
type_convert
<
float
>
(
x
);
return
static_cast
<
half_t
>
(
x_fp32
);
}
// convert fp16 to bfp16 via fp32
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
half_t
>
(
half_t
x
)
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
// convert bfp16 to int32 via fp32
template
<
>
inline
__host__
__device__
constexpr
int32_t
type_convert
<
int32_t
,
bhalf_t
>
(
bhalf_t
x
)
{
float
x_fp32
=
type_convert
<
float
>
(
x
);
return
static_cast
<
int32_t
>
(
x_fp32
);
}
// convert int32 to bfp16 via fp32
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
int32_t
>
(
int32_t
x
)
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
// convert bfp16 to int8 via fp32
template
<
>
inline
__host__
__device__
constexpr
int8_t
type_convert
<
int8_t
,
bhalf_t
>
(
bhalf_t
x
)
{
float
x_fp32
=
type_convert
<
float
>
(
x
);
return
static_cast
<
int8_t
>
(
x_fp32
);
}
// convert int8 to bfp16 via fp32
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
int8_t
>
(
int8_t
x
)
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
// Declare a template function for bf16 conversion using RTN
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
bf16_convert_rtn
(
X
x
);
// Convert fp32 to bf16 with RTN if higher precision is needed
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
bf16_convert_rtn
<
bhalf_t
,
float
>
(
float
x
)
{
union
{
float
fp32
;
uint32_t
int32
;
}
u
=
{
x
};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool
flag0
=
~
u
.
int32
&
0x7f800000
;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool
flag1
=
!
flag0
&&
(
u
.
int32
&
0xffff
);
u
.
int32
+=
flag0
?
0x7fff
+
((
u
.
int32
>>
16
)
&
1
)
:
0
;
// Round to nearest, round to even
u
.
int32
|=
flag1
?
0x10000
:
0x0
;
// Preserve signaling NaN
return
uint16_t
(
u
.
int32
>>
16
);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
bf16_convert_rtn
<
bhalf_t
,
half_t
>
(
half_t
x
)
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
bf16_convert_rtn
<
bhalf_t
>
(
x_fp32
);
}
// f8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
template
<
typename
T
>
struct
NumericLimits
...
...
@@ -1136,4 +1006,21 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
struct
NumericLimits
<
f8_t
>
{
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
static
constexpr
uint8_t
binary_max
=
0x77
;
// 0b01110111
static
constexpr
uint8_t
binary_lowest
=
0xF7
;
// 0b11110111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
bit_cast
<
f8_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
bit_cast
<
f8_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
bit_cast
<
f8_t
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
bit_cast
<
f8_t
>
(
binary_qnan
);
}
};
}
// namespace ck
include/ck/utility/debug.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -19,7 +19,8 @@ namespace ck {
template
<
AddressSpaceEnum
BufferAddressSpace
,
typename
T
,
typename
ElementSpaceSize
,
bool
InvalidElementUseNumericalZeroValue
>
bool
InvalidElementUseNumericalZeroValue
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
struct
DynamicBuffer
{
using
type
=
T
;
...
...
@@ -77,13 +78,16 @@ struct DynamicBuffer
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
{
return
amd_buffer_load_invalid_element_return_zero
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
return
amd_buffer_load_invalid_element_return_zero
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
{
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
t_per_x
,
coherence
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
,
invalid_element_value_
);
}
}
...
...
@@ -173,7 +177,7 @@ struct DynamicBuffer
{
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_store
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
amd_buffer_store
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
if
constexpr
(
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
&&
...
...
@@ -376,14 +380,19 @@ struct DynamicBuffer
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
true
;
}
};
template
<
AddressSpaceEnum
BufferAddressSpace
,
typename
T
,
typename
ElementSpaceSize
>
template
<
AddressSpaceEnum
BufferAddressSpace
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
,
typename
T
,
typename
ElementSpaceSize
>
__host__
__device__
constexpr
auto
make_dynamic_buffer
(
T
*
p
,
ElementSpaceSize
element_space_size
)
{
return
DynamicBuffer
<
BufferAddressSpace
,
T
,
ElementSpaceSize
,
true
>
{
p
,
element_space_size
};
return
DynamicBuffer
<
BufferAddressSpace
,
T
,
ElementSpaceSize
,
true
,
coherence
>
{
p
,
element_space_size
};
}
template
<
AddressSpaceEnum
BufferAddressSpace
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
,
typename
T
,
typename
ElementSpaceSize
,
typename
X
,
...
...
@@ -391,7 +400,7 @@ template <
__host__
__device__
constexpr
auto
make_dynamic_buffer
(
T
*
p
,
ElementSpaceSize
element_space_size
,
X
invalid_element_value
)
{
return
DynamicBuffer
<
BufferAddressSpace
,
T
,
ElementSpaceSize
,
false
>
{
return
DynamicBuffer
<
BufferAddressSpace
,
T
,
ElementSpaceSize
,
false
,
coherence
>
{
p
,
element_space_size
,
invalid_element_value
};
}
...
...
include/ck/utility/enable_if.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/utility/f8_utils.hpp
0 → 100644
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
namespace
ck
{
// fp8 rounding modes
// use standard for rounding to nearest, the faster one
// use stochastic for stochastic rounding, helps to avoid error accumulation
enum
class
f8_rounding_mode
{
standard
,
stochastic
};
}
// namespace ck
namespace
ck
::
utils
{
namespace
{
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
run_cast_to_f8
(
T
x
,
uint32_t
rng
)
{
// check data type
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
// fp8 exponent/mantissa layout
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_mant
=
3
;
// resulting type exponent/mantissa layout
constexpr
int
type_exp
=
is_half
?
5
:
8
;
constexpr
int
type_mant
=
is_half
?
10
:
23
;
int
exponent
;
uint32_t
head
,
mantissa
,
sign
;
// nan code is same for float and half
constexpr
uint8_t
nan_code
=
0x80
;
constexpr
uint32_t
nan_mask
=
is_half
?
0x7C00
:
0x7F800000
;
// convert to bitwise
typedef
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
T_bitwise
;
T_bitwise
x_bitwise
=
*
(
reinterpret_cast
<
T_bitwise
*>
(
&
x
));
// unpack the input, depends on datatype
if
constexpr
(
is_float
)
{
head
=
x_bitwise
&
0xFF800000
;
mantissa
=
x_bitwise
&
0x7FFFFF
;
exponent
=
(
head
>>
type_mant
)
&
0xFF
;
sign
=
head
>>
(
type_exp
+
type_mant
);
}
else
if
constexpr
(
is_half
)
{
head
=
x_bitwise
&
0xFC00
;
mantissa
=
x_bitwise
&
0x3FF
;
exponent
=
(
head
>>
type_mant
)
&
0x1F
;
sign
=
head
>>
(
type_exp
+
type_mant
);
}
uint32_t
signed_inf
=
(
sign
<<
(
type_exp
+
type_mant
))
+
(((
1
<<
type_exp
)
-
1
)
<<
type_mant
);
uint32_t
drop_mask
=
(
1
<<
(
type_mant
-
f8_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
f8_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type_exp
-
1
))
-
(
1
<<
(
f8_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
if
constexpr
(
negative_zero_nan
)
{
if
((
x_bitwise
&
nan_mask
)
==
nan_mask
)
return
nan_code
;
}
else
{
if
((
x_bitwise
&
nan_mask
)
==
nan_mask
)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
// check if x is 0.0
if
(
x_bitwise
==
0
)
return
0
;
exponent
-=
exp_low_cutoff
-
1
;
if
(
exponent
<=
0
)
drop_mask
=
(
1
<<
(
type_mant
-
f8_mant
+
1
-
exponent
))
-
1
;
mantissa
+=
1
<<
type_mant
;
// apply random number if needed
mantissa
+=
(
stoch
?
rng
:
mantissa
)
&
drop_mask
;
if
(
mantissa
>=
(
2
<<
type_mant
))
{
mantissa
>>=
1
;
exponent
++
;
}
mantissa
>>=
(
type_mant
-
f8_mant
);
// check negative exponent
if
(
exponent
<=
0
)
{
if
(
x_bitwise
==
0
)
return
0
;
else
{
// subnormal range; represented by a subnormal float8 (exponent 0)
// and involves loss of accuracy
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
}
}
// above range: quantize to maximum possible float of the same sign
else
if
(
exponent
>
max_exp
)
{
if
(
clip
)
{
mantissa
=
(
1
<<
f8_mant
)
-
1
;
exponent
=
max_exp
;
}
else
{
return
signed_inf
;
}
}
// check if x is 0.0 or -0.0
if
(
exponent
==
0
&&
mantissa
==
0
)
return
negative_zero_nan
?
0
:
(
sign
<<
(
f8_exp
+
f8_mant
));
mantissa
&=
(
1
<<
f8_mant
)
-
1
;
return
(
sign
<<
(
f8_exp
+
f8_mant
))
|
(
exponent
<<
f8_mant
)
|
mantissa
;
}
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
run_cast_from_f8
(
f8_t
x
)
{
// check data type
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
// fp8 exponent/mantissa layout
constexpr
int
f8_exp
=
4
;
constexpr
int
f8_mant
=
3
;
// resulting type exponent/mantissa layout
constexpr
int
type_exp
=
is_half
?
5
:
8
;
constexpr
int
type_mant
=
is_half
?
10
:
23
;
// prepare the codes
constexpr
uint8_t
nan_code
=
0x80
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
if
constexpr
(
is_half
)
{
constexpr
uint16_t
ihInf
=
0x7C00
;
constexpr
uint16_t
ihNegInf
=
0xFC00
;
constexpr
uint16_t
ihNaN
=
0x7C01
;
constexpr
uint16_t
ihNeg0
=
0x8000
;
fInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
half_t
*>
(
&
ihNeg0
));
}
else
if
constexpr
(
is_float
)
{
constexpr
uint32_t
ifInf
=
0x7F800000
;
constexpr
uint32_t
ifNegInf
=
0xFF800000
;
constexpr
uint32_t
ifNaN
=
0x7F800001
;
constexpr
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifInf
));
fNegInf
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNegInf
));
fNaN
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNaN
));
fNeg0
=
*
(
reinterpret_cast
<
const
float
*>
(
&
ifNeg0
));
}
// unpack the input
uint32_t
sign
=
x
>>
(
f8_exp
+
f8_mant
);
uint32_t
mantissa
=
x
&
((
1
<<
f8_mant
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
f8_mant
;
constexpr
int
exp_low_cutoff
=
(
1
<<
(
type_exp
-
1
))
-
(
1
<<
(
f8_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
typename
std
::
conditional
<
std
::
is_same
<
T
,
half_t
>::
value
,
uint16_t
,
uint32_t
>::
type
retval
;
if
constexpr
(
negative_zero_nan
)
{
if
(
x
==
nan_code
)
return
fNaN
;
}
else
{
if
(
x
==
nan_code
)
return
fNeg0
;
if
(
exponent
==
((
1
<<
f8_exp
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
}
// subnormal input
if
(
exponent
==
0
)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
__builtin_clz
(
mantissa
)
-
((
1
+
type_exp
+
type_mant
)
-
f8_mant
);
mantissa
<<=
sh
;
mantissa
&=
((
1
<<
f8_mant
)
-
1
);
exponent
+=
1
-
sh
;
}
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
type_mant
-
f8_mant
;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if
(
exponent
<=
0
)
{
mantissa
|=
1
<<
type_mant
;
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
}
retval
=
(
sign
<<
(
type_exp
+
type_mant
))
|
(
exponent
<<
type_mant
)
|
mantissa
;
return
*
(
reinterpret_cast
<
const
T
*>
(
&
retval
));
}
}
// namespace
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
cast_to_f8
(
T
x
,
uint32_t
rng
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted to f8."
);
return
run_cast_to_f8
<
T
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
}
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
cast_from_f8
(
f8_t
x
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported."
);
// check if x is 0.0
if
(
x
==
0
)
return
static_cast
<
T
>
(
0
);
return
run_cast_from_f8
<
T
,
negative_zero_nan
>
(
x
);
}
}
// namespace ck::utils
include/ck/utility/functional.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/utility/functional2.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/utility/functional3.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/utility/functional4.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
...
...
include/ck/utility/generic_memory_space_atomic.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "data_type.hpp"
...
...
Prev
1
…
23
24
25
26
27
28
29
30
31
…
50
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment