Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7e2f7c95
"docs/source/vscode:/vscode.git/clone" did not exist on "8dfff7c01529a1a476696691626b261f92fd19e3"
Commit
7e2f7c95
authored
Oct 23, 2024
by
Andriy Roshchenko
Browse files
Enable build of example_gemm_xdl_fp8_bf8 test.
parent
043709c2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
66 additions
and
46 deletions
+66
-46
include/ck/utility/amd_ck_fp8.hpp
include/ck/utility/amd_ck_fp8.hpp
+58
-42
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+8
-4
No files found.
include/ck/utility/amd_ck_fp8.hpp
View file @
7e2f7c95
...
@@ -291,6 +291,9 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
...
@@ -291,6 +291,9 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
}
// namespace fp8_impl
}
// namespace fp8_impl
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
;
struct
f8_ocp_t
struct
f8_ocp_t
{
{
using
data_type
=
fp8_storage_t
;
using
data_type
=
fp8_storage_t
;
...
@@ -336,8 +339,51 @@ struct f8_ocp_t
...
@@ -336,8 +339,51 @@ struct f8_ocp_t
}
}
};
};
template
<
typename
T
,
index_t
N
>
struct
bf8_ocp_t
struct
non_native_vector_base
;
{
using
data_type
=
fp8_storage_t
;
data_type
data
;
static
constexpr
ck_saturation_t
default_saturation
=
CK_SATFINITE
;
static
constexpr
ck_fp8_interpretation_t
default_interpret
=
CK_E5M2_OCP
;
static
constexpr
unsigned
int
we
=
5
;
// exponent width
static
constexpr
unsigned
int
wm
=
2
;
// mantissa width
__host__
__device__
constexpr
bool
operator
==
(
const
bf8_ocp_t
&
other
)
const
{
return
(
data
==
other
.
data
)
&&
(
fp8_impl
::
ocp_bf8_is_nan
(
data
)
==
false
);
// NaN != NaN
}
#if CK_USE_OCP_FP8
__host__
__device__
explicit
operator
float
()
const
#else
__host__
explicit
operator
float
()
const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return
fp8_impl
::
cast_to_f32_from_f8
<
default_interpret
>
(
this
->
data
);
#else
return
fp8_impl
::
cast_from_f8
<
float
,
wm
,
we
,
false
>
(
this
->
data
);
// XXX: clip==false must be consistent with operator _Float16
#endif
}
#if CK_USE_OCP_FP8
__host__
__device__
explicit
operator
_Float16
()
const
#else
__host__
explicit
operator
_Float16
()
const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return
static_cast
<
_Float16
>
(
fp8_impl
::
cast_to_f32_from_f8
<
default_interpret
>
(
this
->
data
));
#else
return
fp8_impl
::
cast_from_f8
<
_Float16
,
wm
,
we
,
false
>
(
this
->
data
);
// XXX: clip==false must be consistent with operator float
#endif
}
};
template
<
index_t
N
>
template
<
index_t
N
>
struct
non_native_vector_base
<
f8_ocp_t
,
N
>
struct
non_native_vector_base
<
f8_ocp_t
,
N
>
...
@@ -383,50 +429,20 @@ struct non_native_vector_base<f8_ocp_t, 2>
...
@@ -383,50 +429,20 @@ struct non_native_vector_base<f8_ocp_t, 2>
}
}
};
};
struct
bf8_ocp_t
template
<
index_t
N
>
struct
non_native_vector_base
<
bf8_ocp_t
,
N
>
{
{
using
data_type
=
fp8_storage_t
;
using
data_t
=
bf8_ocp_t
::
data_type
;
data_type
data
;
using
data_v
=
data_t
__attribute__
((
ext_vector_type
(
sizeof
(
data_t
)
*
N
)));
using
type
=
non_native_vector_base
<
bf8_ocp_t
,
N
>
;
static
constexpr
ck_saturation_t
default_saturation
=
CK_SATFINITE
;
static
constexpr
ck_fp8_interpretation_t
default_interpret
=
CK_E5M2_OCP
;
static
constexpr
unsigned
int
we
=
5
;
// exponent width
static
constexpr
unsigned
int
wm
=
2
;
// mantissa width
__host__
__device__
constexpr
bool
operator
==
(
const
bf8_ocp_t
&
other
)
const
{
return
(
data
==
other
.
data
)
&&
(
fp8_impl
::
ocp_bf8_is_nan
(
data
)
==
false
);
// NaN != NaN
}
#if CK_USE_OCP_FP8
data_v
d
;
// storage vector
__host__
__device__
explicit
operator
float
()
const
#else
__host__
__device__
non_native_vector_base
()
=
default
;
__host__
explicit
operator
float
()
const
__host__
__device__
non_native_vector_base
(
data_t
a
)
:
d
{
a
}
{}
#endif
__host__
__device__
non_native_vector_base
(
data_v
v
)
:
d
{
v
}
{}
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return
fp8_impl
::
cast_to_f32_from_f8
<
default_interpret
>
(
this
->
data
);
#else
return
fp8_impl
::
cast_from_f8
<
float
,
wm
,
we
,
false
>
(
this
->
data
);
// XXX: clip==false must be consistent with operator _Float16
#endif
}
#if CK_USE_OCP_FP8
__host__
__device__
operator
data_v
()
const
{
return
d
;
}
__host__
__device__
explicit
operator
_Float16
()
const
#else
__host__
explicit
operator
_Float16
()
const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return
static_cast
<
_Float16
>
(
fp8_impl
::
cast_to_f32_from_f8
<
default_interpret
>
(
this
->
data
));
#else
return
fp8_impl
::
cast_from_f8
<
_Float16
,
wm
,
we
,
false
>
(
this
->
data
);
// XXX: clip==false must be consistent with operator float
#endif
}
};
};
namespace
fp8_impl
{
namespace
fp8_impl
{
...
...
include/ck/utility/data_type.hpp
View file @
7e2f7c95
...
@@ -1036,10 +1036,6 @@ struct non_native_vector_base
...
@@ -1036,10 +1036,6 @@ struct non_native_vector_base
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
T
,
N
>>
;
struct
scalar_type
<
non_native_vector_base
<
T
,
N
>>
;
// {
// using type = T;
// static constexpr index_t vector_size = N;
// };
template
<
index_t
N
>
template
<
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
f8_ocp_t
,
N
>>
struct
scalar_type
<
non_native_vector_base
<
f8_ocp_t
,
N
>>
...
@@ -1049,6 +1045,14 @@ struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
...
@@ -1049,6 +1045,14 @@ struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
static
constexpr
index_t
vector_size
=
N
;
static
constexpr
index_t
vector_size
=
N
;
};
};
template
<
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
bf8_ocp_t
,
N
>>
{
using
type
=
typename
non_native_vector_base
<
bf8_ocp_t
,
N
>::
data_t
;
static
constexpr
index_t
vector_size
=
N
;
};
// non-native vector_type implementation
// non-native vector_type implementation
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment