Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
66c45110
Commit
66c45110
authored
Dec 18, 2024
by
illsilin
Browse files
Merge branch 'gfx950' of github.com:ROCm/composable_kernel-internal into gfx950
parents
12b16cc3
c847d5be
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
197 additions
and
8 deletions
+197
-8
include/ck/ck.hpp
include/ck/ck.hpp
+3
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+121
-1
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+21
-6
test/data_type/test_bf6.cpp
test/data_type/test_bf6.cpp
+26
-0
test/data_type/test_fp4.cpp
test/data_type/test_fp4.cpp
+0
-1
test/data_type/test_fp6.cpp
test/data_type/test_fp6.cpp
+26
-0
No files found.
include/ck/ck.hpp
View file @
66c45110
...
@@ -158,6 +158,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -158,6 +158,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// set rounding to nearest even as default for f8 conversions
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0
#define CK_USE_SR_F8_CONVERSION 0
// set rounding to nearest even as default for f6 conversions
#define CK_USE_SR_F6_CONVERSION 0
// set rounding to nearest even as default for f4 conversions
// set rounding to nearest even as default for f4 conversions
#define CK_USE_SR_F4_CONVERSION 0
#define CK_USE_SR_F4_CONVERSION 0
...
...
include/ck/utility/data_type.hpp
View file @
66c45110
...
@@ -12,6 +12,8 @@ using bhalf_t = ushort;
...
@@ -12,6 +12,8 @@ using bhalf_t = ushort;
using
half_t
=
_Float16
;
using
half_t
=
_Float16
;
using
int4_t
=
_BitInt
(
4
);
using
int4_t
=
_BitInt
(
4
);
using
f4_t
=
unsigned
_BitInt
(
4
);
using
f4_t
=
unsigned
_BitInt
(
4
);
using
f6_t
=
_BitInt
(
6
);
// e2m3 format
using
bf6_t
=
unsigned
_BitInt
(
6
);
// e3m2 format
struct
e8m0_bexp_t
struct
e8m0_bexp_t
{
{
...
@@ -60,7 +62,8 @@ inline constexpr bool is_native_type()
...
@@ -60,7 +62,8 @@ inline constexpr bool is_native_type()
return
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
half_t
>::
value
||
return
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
bhalf_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
bhalf_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
uint8_t
>::
value
||
is_same
<
T
,
f8_fnuz_t
>::
value
||
is_same
<
T
,
uint8_t
>::
value
||
is_same
<
T
,
f8_fnuz_t
>::
value
||
is_same
<
T
,
bf8_fnuz_t
>::
value
||
is_same
<
T
,
bool
>::
value
||
is_same
<
T
,
f4_t
>::
value
;
is_same
<
T
,
bf8_fnuz_t
>::
value
||
is_same
<
T
,
bool
>::
value
||
is_same
<
T
,
f4_t
>::
value
||
is_same
<
T
,
f6_t
>::
value
||
is_same
<
T
,
bf6_t
>::
value
;
}
}
// vector_type
// vector_type
...
@@ -2076,6 +2079,65 @@ struct NumericLimits<f4_t>
...
@@ -2076,6 +2079,65 @@ struct NumericLimits<f4_t>
}
}
};
};
template
<
>
struct
NumericLimits
<
f6_t
>
{
static
constexpr
uint8_t
binary_min_normal
=
0x08
;
// 0b001000
static
constexpr
uint8_t
binary_max_normal
=
0x1F
;
// 0b011111
static
constexpr
uint8_t
binary_lowest_normal
=
0x3F
;
// 0b111111
static
constexpr
uint8_t
binary_min_subnorm
=
0x01
;
// 0b000001
static
constexpr
uint8_t
binary_max_subnorm
=
0x07
;
// 0b000111
static
constexpr
float
data_max_normal_number
=
7.5
;
static
constexpr
float
data_min_subnormal_number
=
0.125
;
__host__
__device__
static
constexpr
f6_t
Min
()
{
return
f6_t
(
binary_min_normal
&
0b111111
);
}
__host__
__device__
static
constexpr
f6_t
Max
()
{
return
f6_t
(
binary_max_normal
&
0b111111
);
}
__host__
__device__
static
constexpr
f6_t
Lowest
()
{
return
f6_t
(
binary_lowest_normal
&
0b111111
);
}
__host__
__device__
static
constexpr
f6_t
MinSubnorm
()
{
return
f6_t
(
binary_min_subnorm
&
0b111111
);
}
__host__
__device__
static
constexpr
f6_t
MaxSubnorm
()
{
return
f6_t
(
binary_max_subnorm
&
0b111111
);
}
__host__
__device__
static
constexpr
float
DataMaxNorm
()
{
return
data_max_normal_number
;
}
__host__
__device__
static
constexpr
float
DataMinSubnorm
()
{
return
data_min_subnormal_number
;
}
};
template
<
>
struct
NumericLimits
<
bf6_t
>
{
static
constexpr
uint8_t
binary_min_normal
=
0x08
;
// 0b001000
static
constexpr
uint8_t
binary_max_normal
=
0x1F
;
// 0b011111
static
constexpr
uint8_t
binary_lowest_normal
=
0x3F
;
// 0b111111
static
constexpr
uint8_t
binary_min_subnorm
=
0x01
;
// 0b000001
static
constexpr
uint8_t
binary_max_subnorm
=
0x03
;
// 0b000011
static
constexpr
float
data_max_normal_number
=
28
;
static
constexpr
float
data_min_subnormal_number
=
0.0625
;
__host__
__device__
static
constexpr
bf6_t
Min
()
{
return
bf6_t
(
binary_min_normal
);
}
__host__
__device__
static
constexpr
bf6_t
Max
()
{
return
bf6_t
(
binary_max_normal
);
}
__host__
__device__
static
constexpr
bf6_t
Lowest
()
{
return
bf6_t
(
binary_lowest_normal
);
}
__host__
__device__
static
constexpr
bf6_t
MinSubnorm
()
{
return
bf6_t
(
binary_min_subnorm
);
}
__host__
__device__
static
constexpr
bf6_t
MaxSubnorm
()
{
return
bf6_t
(
binary_max_subnorm
);
}
__host__
__device__
static
constexpr
float
DataMaxNorm
()
{
return
data_max_normal_number
;
}
__host__
__device__
static
constexpr
float
DataMinSubnorm
()
{
return
data_min_subnormal_number
;
}
};
template
<
>
template
<
>
struct
NumericLimits
<
e8m0_bexp_t
>
struct
NumericLimits
<
e8m0_bexp_t
>
{
{
...
@@ -2219,6 +2281,64 @@ struct NumericUtils<f4_t>
...
@@ -2219,6 +2281,64 @@ struct NumericUtils<f4_t>
using
bitwise_type
=
uint8_t
;
using
bitwise_type
=
uint8_t
;
};
};
template
<
>
struct
NumericUtils
<
f6_t
>
{
static
constexpr
int
exp
=
2
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
bias
=
1
;
static
constexpr
uint32_t
sr_shift
=
12
;
static
constexpr
int
unbiased_exp_min
=
0
;
static
constexpr
int
unbiased_exp_max
=
2
;
static
constexpr
int
biased_exp_min
=
1
;
static
constexpr
int
biased_exp_max
=
3
;
static
constexpr
uint8_t
positive_zero_mask
=
0b000000
;
static
constexpr
uint8_t
negative_zero_mask
=
0b100000
;
static
constexpr
uint8_t
data_max_positive_normal_mask
=
0b011111
;
static
constexpr
uint8_t
data_max_negative_normal_mask
=
0b111111
;
static
constexpr
uint8_t
data_max_positive_subnormal_mask
=
0b000111
;
static
constexpr
uint8_t
data_max_negative_subnormal_mask
=
0b100111
;
static
constexpr
bool
has_inf
=
false
;
static
constexpr
bool
has_nan
=
false
;
static
constexpr
bool
has_zero
=
true
;
using
bitwise_type
=
uint8_t
;
};
template
<
>
struct
NumericUtils
<
bf6_t
>
{
static
constexpr
int
exp
=
3
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
3
;
static
constexpr
uint32_t
sr_shift
=
11
;
static
constexpr
int
unbiased_exp_min
=
-
2
;
static
constexpr
int
unbiased_exp_max
=
4
;
static
constexpr
int
biased_exp_min
=
1
;
static
constexpr
int
biased_exp_max
=
7
;
static
constexpr
uint8_t
positive_zero_mask
=
0b000000
;
static
constexpr
uint8_t
negative_zero_mask
=
0b100000
;
static
constexpr
uint8_t
data_max_positive_normal_mask
=
0b011111
;
static
constexpr
uint8_t
data_max_negative_normal_mask
=
0b111111
;
static
constexpr
uint8_t
data_max_positive_subnormal_mask
=
0b000011
;
static
constexpr
uint8_t
data_max_negative_subnormal_mask
=
0b100011
;
static
constexpr
bool
has_inf
=
false
;
static
constexpr
bool
has_nan
=
false
;
static
constexpr
bool
has_zero
=
true
;
using
bitwise_type
=
uint8_t
;
};
template
<
>
template
<
>
struct
NumericUtils
<
e8m0_bexp_t
>
struct
NumericUtils
<
e8m0_bexp_t
>
{
{
...
...
test/data_type/CMakeLists.txt
View file @
66c45110
...
@@ -9,8 +9,6 @@ if (USE_BITINT_EXTENSION_INT4)
...
@@ -9,8 +9,6 @@ if (USE_BITINT_EXTENSION_INT4)
endif
()
endif
()
endif
()
endif
()
add_custom_target
(
test_fp8
)
add_custom_target
(
test_fp8
)
if
(
CK_USE_OCP_FP8
)
if
(
CK_USE_OCP_FP8
)
...
@@ -42,11 +40,28 @@ if (CK_USE_FNUZ_FP8)
...
@@ -42,11 +40,28 @@ if (CK_USE_FNUZ_FP8)
add_dependencies
(
test_fp8 test_fp8_fnuz
)
add_dependencies
(
test_fp8 test_fp8_fnuz
)
add_dependencies
(
test_fp8 test_bf8_fnuz
)
add_dependencies
(
test_fp8 test_bf8_fnuz
)
endif
()
endif
()
add_gtest_executable
(
test_fp4 test_fp4.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp4 PRIVATE utility
)
endif
()
if
(
GPU_TARGETS MATCHES
"gfx950"
)
add_custom_target
(
test_mx_data_types
)
add_gtest_executable
(
test_fp4 test_fp4.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp4 PRIVATE utility
)
endif
()
add_dependencies
(
test_mx_data_types test_fp4
)
add_gtest_executable
(
test_fp6 test_fp6.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_fp6 PRIVATE utility
)
endif
()
add_dependencies
(
test_mx_data_types test_fp6
)
add_gtest_executable
(
test_bf6 test_bf6.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_bf6 PRIVATE utility
)
endif
()
add_dependencies
(
test_mx_data_types test_bf6
)
endif
()
add_gtest_executable
(
test_custom_type test_custom_type.cpp
)
add_gtest_executable
(
test_custom_type test_custom_type.cpp
)
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_custom_type PRIVATE utility
)
target_link_libraries
(
test_custom_type PRIVATE utility
)
...
...
test/data_type/test_bf6.cpp
0 → 100644
View file @
66c45110
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using
ck
::
bf6_t
;
using
ck
::
e8m0_bexp_t
;
using
ck
::
Number
;
using
ck
::
scaled_type_convert
;
using
ck
::
type_convert
;
using
ck
::
vector_type
;
using
ck
::
utils
::
cast_from_float
;
using
ck
::
utils
::
cast_to_float
;
TEST
(
BF6
,
NumericLimits
)
{
EXPECT_EQ
(
ck
::
NumericLimits
<
bf6_t
>::
Min
(),
bf6_t
(
0b001000
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf6_t
>::
Max
(),
bf6_t
(
0b011111
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf6_t
>::
Lowest
(),
bf6_t
(
0b111111
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf6_t
>::
MinSubnorm
(),
bf6_t
(
0b000001
));
EXPECT_EQ
(
ck
::
NumericLimits
<
bf6_t
>::
MaxSubnorm
(),
bf6_t
(
0b000011
));
}
test/data_type/test_fp4.cpp
View file @
66c45110
...
@@ -21,7 +21,6 @@ using ck::utils::cast_to_float;
...
@@ -21,7 +21,6 @@ using ck::utils::cast_to_float;
TEST
(
FP4
,
NumericLimits
)
TEST
(
FP4
,
NumericLimits
)
{
{
// constants given for negative zero nan mode
EXPECT_EQ
(
ck
::
NumericLimits
<
f4_t
>::
Min
(),
f4_t
{
0x2
});
EXPECT_EQ
(
ck
::
NumericLimits
<
f4_t
>::
Min
(),
f4_t
{
0x2
});
EXPECT_EQ
(
ck
::
NumericLimits
<
f4_t
>::
Max
(),
f4_t
{
0x7
});
EXPECT_EQ
(
ck
::
NumericLimits
<
f4_t
>::
Max
(),
f4_t
{
0x7
});
EXPECT_EQ
(
ck
::
NumericLimits
<
f4_t
>::
Lowest
(),
f4_t
{
0xF
});
EXPECT_EQ
(
ck
::
NumericLimits
<
f4_t
>::
Lowest
(),
f4_t
{
0xF
});
...
...
test/data_type/test_fp6.cpp
0 → 100644
View file @
66c45110
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using
ck
::
e8m0_bexp_t
;
using
ck
::
f6_t
;
using
ck
::
Number
;
using
ck
::
scaled_type_convert
;
using
ck
::
type_convert
;
using
ck
::
vector_type
;
using
ck
::
utils
::
cast_from_float
;
using
ck
::
utils
::
cast_to_float
;
TEST
(
FP6
,
NumericLimits
)
{
EXPECT_EQ
(
ck
::
NumericLimits
<
f6_t
>::
Min
(),
f6_t
(
0b001000
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f6_t
>::
Max
(),
f6_t
(
0b011111
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f6_t
>::
Lowest
(),
f6_t
(
0b111111
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f6_t
>::
MinSubnorm
(),
f6_t
(
0b000001
));
EXPECT_EQ
(
ck
::
NumericLimits
<
f6_t
>::
MaxSubnorm
(),
f6_t
(
0b000111
));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment