Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bc1d4fa4
Commit
bc1d4fa4
authored
Nov 28, 2024
by
illsilin
Browse files
Merge branch 'gfx950' into promote_ocp_fp8
parents
001a32c5
97042d87
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
129 additions
and
116 deletions
+129
-116
include/ck/utility/amd_ck_fp8.hpp
include/ck/utility/amd_ck_fp8.hpp
+0
-108
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+115
-8
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+14
-0
No files found.
include/ck/utility/amd_ck_fp8.hpp
View file @
bc1d4fa4
...
@@ -293,9 +293,6 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
...
@@ -293,9 +293,6 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
}
// namespace fp8_impl
}
// namespace fp8_impl
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
;
struct
f8_ocp_t
struct
f8_ocp_t
{
{
using
data_type
=
fp8_storage_t
;
using
data_type
=
fp8_storage_t
;
...
@@ -389,111 +386,6 @@ struct bf8_ocp_t
...
@@ -389,111 +386,6 @@ struct bf8_ocp_t
}
}
};
};
template
<
index_t
N
>
struct
non_native_vector_base
<
f8_ocp_t
,
N
>
{
using
data_t
=
f8_ocp_t
::
data_type
;
static_assert
(
sizeof
(
f8_ocp_t
)
==
sizeof
(
data_t
),
"non_native_vector_base storage size mismatch"
);
using
data_v
=
data_t
__attribute__
((
ext_vector_type
(
sizeof
(
data_t
)
*
N
)));
using
type
=
non_native_vector_base
<
f8_ocp_t
,
N
>
;
data_v
d
;
// storage vector
__host__
__device__
non_native_vector_base
()
=
default
;
__host__
__device__
non_native_vector_base
(
data_t
a
)
:
d
{
a
}
{}
__host__
__device__
non_native_vector_base
(
f8_ocp_t
f
)
:
non_native_vector_base
(
f
.
data
)
{}
__host__
__device__
non_native_vector_base
(
data_v
v
)
:
d
{
v
}
{}
__host__
__device__
operator
data_v
()
const
{
return
d
;
}
};
template
<
>
struct
non_native_vector_base
<
f8_ocp_t
,
1
>
{
using
data_t
=
f8_ocp_t
::
data_type
;
using
data_v
=
data_t
__attribute__
((
ext_vector_type
(
sizeof
(
data_t
))));
using
type
=
non_native_vector_base
<
f8_ocp_t
,
1
>
;
data_v
d
;
// storage vector
__host__
__device__
non_native_vector_base
()
=
default
;
__host__
__device__
non_native_vector_base
(
data_t
a
)
:
d
{
a
}
{}
__host__
__device__
non_native_vector_base
(
f8_ocp_t
f
)
:
non_native_vector_base
(
f
.
data
)
{}
__host__
__device__
non_native_vector_base
(
data_v
v
)
:
d
{
v
}
{}
__host__
__device__
operator
data_v
()
const
{
return
d
;
}
__host__
__device__
operator
data_t
()
const
{
return
d
[
0
];
}
__host__
__device__
operator
f8_ocp_t
()
const
{
return
f8_ocp_t
{
d
[
0
]};
}
};
template
<
>
struct
non_native_vector_base
<
f8_ocp_t
,
2
>
{
using
data_t
=
f8_ocp_t
::
data_type
;
using
type
=
non_native_vector_base
<
f8_ocp_t
,
2
>
;
using
data_v
=
fp8_impl
::
fp8x2_storage_t
;
// type of storage vector
data_v
d
;
// storage vector
__host__
__device__
non_native_vector_base
()
=
default
;
__host__
__device__
non_native_vector_base
(
data_t
a
)
:
d
{
a
}
{}
__host__
__device__
non_native_vector_base
(
f8_ocp_t
f
)
:
non_native_vector_base
(
f
.
data
)
{}
__host__
__device__
non_native_vector_base
(
data_v
v
)
:
d
{
v
}
{}
__host__
__device__
operator
data_v
()
const
{
return
d
;
}
using
float2_t
=
fp8_impl
::
float2_t
;
#if CK_USE_OCP_FP8
__host__
__device__
explicit
operator
float2_t
()
const
#else
__host__
explicit
operator
float2_t
()
const
#endif
{
#if CK_OCP_FP8_CVT_FAST_PATH
return
fp8_impl
::
cast_to_f32x2_from_f8x2
<
f8_ocp_t
::
default_interpret
>
(
d
);
#else
return
float2_t
{
fp8_impl
::
cast_from_f8
<
float
,
f8_ocp_t
::
wm
,
f8_ocp_t
::
we
,
false
>
(
d
[
0
]),
fp8_impl
::
cast_from_f8
<
float
,
f8_ocp_t
::
wm
,
f8_ocp_t
::
we
,
false
>
(
d
[
1
])};
#endif
}
};
template
<
index_t
N
>
struct
non_native_vector_base
<
bf8_ocp_t
,
N
>
{
using
data_t
=
bf8_ocp_t
::
data_type
;
using
data_v
=
data_t
__attribute__
((
ext_vector_type
(
sizeof
(
data_t
)
*
N
)));
using
type
=
non_native_vector_base
<
bf8_ocp_t
,
N
>
;
data_v
d
;
// storage vector
__host__
__device__
non_native_vector_base
()
=
default
;
__host__
__device__
non_native_vector_base
(
data_t
a
)
:
d
{
a
}
{}
__host__
__device__
non_native_vector_base
(
data_v
v
)
:
d
{
v
}
{}
__host__
__device__
operator
data_v
()
const
{
return
d
;
}
};
template
<
>
struct
non_native_vector_base
<
bf8_ocp_t
,
1
>
{
using
data_t
=
bf8_ocp_t
::
data_type
;
using
data_v
=
data_t
__attribute__
((
ext_vector_type
(
sizeof
(
data_t
))));
using
type
=
non_native_vector_base
<
bf8_ocp_t
,
1
>
;
data_v
d
;
// storage vector
__host__
__device__
non_native_vector_base
()
=
default
;
__host__
__device__
non_native_vector_base
(
data_t
a
)
:
d
{
a
}
{}
__host__
__device__
non_native_vector_base
(
bf8_ocp_t
f
)
:
non_native_vector_base
(
f
.
data
)
{}
__host__
__device__
non_native_vector_base
(
data_v
v
)
:
d
{
v
}
{}
__host__
__device__
operator
data_v
()
const
{
return
d
;
}
__host__
__device__
operator
data_t
()
const
{
return
d
[
0
];
}
__host__
__device__
operator
bf8_ocp_t
()
const
{
return
bf8_ocp_t
{
d
[
0
]};
}
};
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
static
inline
constexpr
bool
fp8_is_nan
(
T
);
__host__
__device__
static
inline
constexpr
bool
fp8_is_nan
(
T
);
...
...
include/ck/utility/data_type.hpp
View file @
bc1d4fa4
...
@@ -1024,17 +1024,124 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
...
@@ -1024,17 +1024,124 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
}
}
};
};
template
<
typename
T
,
index_t
N
,
typename
Enable
=
void
>
struct
non_native_vector_base
;
template
<
typename
T
>
struct
nnvb_data_t_selector
{
using
type
=
unsigned
_BitInt
(
8
*
sizeof
(
T
));
};
template
<
>
struct
nnvb_data_t_selector
<
f8_ocp_t
>
{
using
type
=
f8_ocp_t
::
data_type
;
};
template
<
>
struct
nnvb_data_t_selector
<
bf8_ocp_t
>
{
using
type
=
bf8_ocp_t
::
data_type
;
};
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
struct
non_native_vector_base
<
T
,
N
,
std
::
enable_if_t
<
sizeof
(
T
)
==
1
||
sizeof
(
T
)
==
2
||
sizeof
(
T
)
==
4
||
sizeof
(
T
)
==
8
>>
{
{
using
data_t
=
typename
nnvb_data_t_selector
<
T
>::
type
;
// select data_t based on the size of T
static_assert
(
sizeof
(
T
)
==
sizeof
(
data_t
),
"non_native_vector_base storage size mismatch"
);
using
data_v
=
data_t
__attribute__
((
ext_vector_type
(
N
)));
using
type
=
non_native_vector_base
<
T
,
N
>
;
using
type
=
non_native_vector_base
<
T
,
N
>
;
__host__
__device__
non_native_vector_base
()
=
default
;
union
alignas
(
next_pow2
(
N
*
sizeof
(
T
)))
__host__
__device__
non_native_vector_base
(
const
type
&
)
=
default
;
{
__host__
__device__
non_native_vector_base
(
type
&&
)
=
default
;
data_v
dN
;
// storage vector;
__host__
__device__
~
non_native_vector_base
()
=
default
;
StaticallyIndexedArray
<
data_t
,
N
>
dxN
;
StaticallyIndexedArray
<
T
,
N
>
dTxN
;
StaticallyIndexedArray
<
data_v
,
1
>
dNx1
;
}
data_
;
__host__
__device__
constexpr
non_native_vector_base
(
data_t
a
)
:
data_
{
data_v
{
a
}}
{}
__host__
__device__
constexpr
non_native_vector_base
(
T
f
)
:
non_native_vector_base
(
bit_cast
<
data_t
>
(
f
))
{
}
__host__
__device__
constexpr
non_native_vector_base
()
:
non_native_vector_base
(
T
{}){};
__host__
__device__
constexpr
non_native_vector_base
(
data_v
v
)
:
data_
{
v
}
{}
__host__
__device__
constexpr
operator
data_v
()
const
{
return
data_
.
dN
;
}
__host__
__device__
constexpr
operator
data_t
()
const
{
if
constexpr
(
N
==
1
)
{
return
data_
.
dxN
[
Number
<
0
>
{}];
}
else
{
return
data_
.
dxN
;
// XXX this should cause an error
}
}
__host__
__device__
constexpr
operator
T
()
const
{
if
constexpr
(
N
==
1
)
{
return
data_
.
dTxN
[
Number
<
0
>
{}];
}
else
{
return
data_
.
dTxN
;
// XXX this should cause an error
}
}
T
d
[
N
];
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same_v
<
X
,
data_t
>
||
is_same_v
<
X
,
T
>
||
is_same_v
<
X
,
data_v
>
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same_v
<
X
,
data_t
>
)
{
return
data_
.
dxN
;
}
else
if
constexpr
(
is_same_v
<
X
,
T
>
)
{
return
data_
.
dTxN
;
}
else
if
constexpr
(
is_same_v
<
X
,
data_v
>
)
{
return
data_
.
dNx1
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same_v
<
X
,
data_t
>
||
is_same_v
<
X
,
T
>
||
is_same_v
<
X
,
data_v
>
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same_v
<
X
,
data_t
>
)
{
return
data_
.
dxN
;
}
else
if
constexpr
(
is_same_v
<
X
,
T
>
)
{
return
data_
.
dTxN
;
}
else
if
constexpr
(
is_same_v
<
X
,
data_v
>
)
{
return
data_
.
dNx1
;
}
else
{
return
err
;
}
}
};
};
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
...
@@ -1073,7 +1180,7 @@ struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1073,7 +1180,7 @@ struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
__host__
__device__
constexpr
vector_type
()
:
data_
{
d1_t
{}}
{}
__host__
__device__
constexpr
vector_type
()
:
data_
{
d1_t
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
{
v
}
}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
...
...
include/ck/utility/type_convert.hpp
View file @
bc1d4fa4
...
@@ -451,6 +451,20 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnu
...
@@ -451,6 +451,20 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnu
#endif
#endif
}
}
template
<
>
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
f8x2_ocp_t
>
(
f8x2_ocp_t
x
)
{
#if CK_OCP_FP8_CVT_FAST_PATH
return
fp8_impl
::
cast_to_f32x2_from_f8x2
<
f8_ocp_t
::
default_interpret
>
(
x
.
AsType
<
fp8_impl
::
fp8x2_storage_t
>
()[
Number
<
0
>
{}]);
#else
return
float2_t
{
fp8_impl
::
cast_from_f8
<
float
,
f8_ocp_t
::
wm
,
f8_ocp_t
::
we
,
false
>
(
x
.
AsType
<
fp8_storage_t
>
()[
Number
<
0
>
{}]),
fp8_impl
::
cast_from_f8
<
float
,
f8_ocp_t
::
wm
,
f8_ocp_t
::
we
,
false
>
(
x
.
AsType
<
fp8_storage_t
>
()[
Number
<
1
>
{}])};
#endif
}
template
<
>
template
<
>
inline
__host__
__device__
half2_t
type_convert
<
half2_t
,
float2_t
>
(
float2_t
x
)
inline
__host__
__device__
half2_t
type_convert
<
half2_t
,
float2_t
>
(
float2_t
x
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment