Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3a64757f
Commit
3a64757f
authored
Jan 16, 2025
by
Rostyslav Geyyer
Browse files
Add vector support
parent
d44b24d1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
256 additions
and
1 deletion
+256
-1
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+256
-1
No files found.
include/ck/utility/data_type.hpp
View file @
3a64757f
...
...
@@ -26,6 +26,7 @@ struct f4x2_pk_t
template
<
index_t
I
>
__host__
__device__
inline
type
unpack
()
const
{
static_assert
(
I
<
2
,
"Index is out of range."
);
if
constexpr
(
I
==
0
)
return
data
&
0b00001111
;
else
...
...
@@ -38,6 +39,126 @@ struct f4x2_pk_t
}
};
// Pack of 16 f6_t values (6 bits each, 96 bits total) stored in 3 uint32_t words.
//
// Bit layout used by pack()/unpack(): element i occupies bits [6*i, 6*i + 6) of the
// 96-bit stream formed by the words in order, word 0 holding the lowest bits. An
// element may straddle two neighbouring words.
// NOTE(review): confirm this LSB-first layout matches every consumer of the raw
// `data` words, and that f6_t converts to/from integers via static_cast.
struct f6x16_pk_t
{
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 3>;
    type data;

    // Constructors are host/device like the rest of the interface so the pack can be
    // created inside __host__ __device__ code.
    __host__ __device__ f6x16_pk_t() : data{type{}} {}
    __host__ __device__ f6x16_pk_t(type init) : data{init} {}

    // Returns element I (0 <= I < 16) of the pack.
    //
    // The 6-bit field is extracted with shifts and masks. Overlaying the storage with
    // an f6_t[16] array cannot work: sizeof(f6_t) >= 1, so such an array is at least
    // 16 bytes while the storage is only 12 bytes.
    template <index_t I>
    __host__ __device__ inline f6_t unpack() const
    {
        static_assert(I >= 0 && I < 16, "Index is out of range.");

        constexpr int bits_per_elem = 6;
        constexpr int bits_per_word = 32;
        constexpr int first_bit     = static_cast<int>(I) * bits_per_elem;
        constexpr int word          = first_bit / bits_per_word;
        constexpr int shift         = first_bit % bits_per_word;

        // Copy the storage into a plain array so words can be indexed directly.
        element_type words[3];
        static_assert(sizeof(words) == sizeof(type), "unexpected storage size");
        __builtin_memcpy(words, &data, sizeof(words));

        element_type bits = words[word] >> shift;
        if constexpr(shift + bits_per_elem > bits_per_word)
        {
            // The field continues in the low bits of the next word.
            bits |= words[word + 1] << (bits_per_word - shift);
        }
        return static_cast<f6_t>(bits & 0x3Fu);
    }

    // Packs the 16 values x[0..15] into the 3-word storage format and returns it.
    // x must point to at least 16 readable f6_t values; only the low 6 bits of each
    // value are kept.
    __host__ __device__ inline type pack(f6_t* x) const
    {
        constexpr int bits_per_elem = 6;
        constexpr int bits_per_word = 32;

        element_type words[3] = {0u, 0u, 0u};
        static_assert(sizeof(words) == sizeof(type), "unexpected storage size");

        for(int i = 0; i < 16; ++i)
        {
            const element_type bits = static_cast<element_type>(x[i]) & 0x3Fu;
            const int first_bit     = i * bits_per_elem;
            const int word          = first_bit / bits_per_word;
            const int shift         = first_bit % bits_per_word;

            words[word] |= bits << shift;
            if(shift + bits_per_elem > bits_per_word)
            {
                // Upper part of the field spills into the next word.
                words[word + 1] |= bits >> (bits_per_word - shift);
            }
        }

        type packed{};
        __builtin_memcpy(&packed, words, sizeof(words));
        return packed;
    }
};
// Pack of 32 f6_t values (6 bits each, 192 bits total) stored in 6 uint32_t words.
//
// Bit layout used by pack()/unpack(): element i occupies bits [6*i, 6*i + 6) of the
// 192-bit stream formed by the words in order, word 0 holding the lowest bits. An
// element may straddle two neighbouring words.
// NOTE(review): confirm this LSB-first layout matches every consumer of the raw
// `data` words, and that f6_t converts to/from integers via static_cast.
struct f6x32_pk_t
{
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 6>;
    type data;

    // Constructors are host/device like the rest of the interface so the pack can be
    // created inside __host__ __device__ code.
    __host__ __device__ f6x32_pk_t() : data{type{}} {}
    __host__ __device__ f6x32_pk_t(type init) : data{init} {}

    // Returns element I (0 <= I < 32) of the pack.
    //
    // The 6-bit field is extracted with shifts and masks. Overlaying the storage with
    // an f6_t[32] array cannot work: sizeof(f6_t) >= 1, so such an array is at least
    // 32 bytes while the storage is only 24 bytes.
    template <index_t I>
    __host__ __device__ inline f6_t unpack() const
    {
        static_assert(I >= 0 && I < 32, "Index is out of range.");

        constexpr int bits_per_elem = 6;
        constexpr int bits_per_word = 32;
        constexpr int first_bit     = static_cast<int>(I) * bits_per_elem;
        constexpr int word          = first_bit / bits_per_word;
        constexpr int shift         = first_bit % bits_per_word;

        // Copy the storage into a plain array so words can be indexed directly.
        element_type words[6];
        static_assert(sizeof(words) == sizeof(type), "unexpected storage size");
        __builtin_memcpy(words, &data, sizeof(words));

        element_type bits = words[word] >> shift;
        if constexpr(shift + bits_per_elem > bits_per_word)
        {
            // The field continues in the low bits of the next word.
            bits |= words[word + 1] << (bits_per_word - shift);
        }
        return static_cast<f6_t>(bits & 0x3Fu);
    }

    // Packs the 32 values x[0..31] into the 6-word storage format and returns it.
    // x must point to at least 32 readable f6_t values; only the low 6 bits of each
    // value are kept.
    __host__ __device__ inline type pack(f6_t* x) const
    {
        constexpr int bits_per_elem = 6;
        constexpr int bits_per_word = 32;

        element_type words[6] = {0u, 0u, 0u, 0u, 0u, 0u};
        static_assert(sizeof(words) == sizeof(type), "unexpected storage size");

        for(int i = 0; i < 32; ++i)
        {
            const element_type bits = static_cast<element_type>(x[i]) & 0x3Fu;
            const int first_bit     = i * bits_per_elem;
            const int word          = first_bit / bits_per_word;
            const int shift         = first_bit % bits_per_word;

            words[word] |= bits << shift;
            if(shift + bits_per_elem > bits_per_word)
            {
                // Upper part of the field spills into the next word.
                words[word + 1] |= bits >> (bits_per_word - shift);
            }
        }

        type packed{};
        __builtin_memcpy(&packed, words, sizeof(words));
        return packed;
    }
};
// Pack of 16 bf6_t values (6 bits each, 96 bits total) stored in 3 uint32_t words.
//
// Bit layout used by pack()/unpack(): element i occupies bits [6*i, 6*i + 6) of the
// 96-bit stream formed by the words in order, word 0 holding the lowest bits. An
// element may straddle two neighbouring words.
// NOTE(review): confirm this LSB-first layout matches every consumer of the raw
// `data` words, and that bf6_t converts to/from integers via static_cast.
struct bf6x16_pk_t
{
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 3>;
    type data;

    // Constructors are host/device like the rest of the interface so the pack can be
    // created inside __host__ __device__ code.
    __host__ __device__ bf6x16_pk_t() : data{type{}} {}
    __host__ __device__ bf6x16_pk_t(type init) : data{init} {}

    // Returns element I (0 <= I < 16) of the pack.
    //
    // The 6-bit field is extracted with shifts and masks. Overlaying the storage with
    // a bf6_t[16] array cannot work: sizeof(bf6_t) >= 1, so such an array is at least
    // 16 bytes while the storage is only 12 bytes.
    template <index_t I>
    __host__ __device__ inline bf6_t unpack() const
    {
        static_assert(I >= 0 && I < 16, "Index is out of range.");

        constexpr int bits_per_elem = 6;
        constexpr int bits_per_word = 32;
        constexpr int first_bit     = static_cast<int>(I) * bits_per_elem;
        constexpr int word          = first_bit / bits_per_word;
        constexpr int shift         = first_bit % bits_per_word;

        // Copy the storage into a plain array so words can be indexed directly.
        element_type words[3];
        static_assert(sizeof(words) == sizeof(type), "unexpected storage size");
        __builtin_memcpy(words, &data, sizeof(words));

        element_type bits = words[word] >> shift;
        if constexpr(shift + bits_per_elem > bits_per_word)
        {
            // The field continues in the low bits of the next word.
            bits |= words[word + 1] << (bits_per_word - shift);
        }
        return static_cast<bf6_t>(bits & 0x3Fu);
    }

    // Packs the 16 values x[0..15] into the 3-word storage format and returns it.
    // x must point to at least 16 readable bf6_t values; only the low 6 bits of each
    // value are kept.
    __host__ __device__ inline type pack(bf6_t* x) const
    {
        constexpr int bits_per_elem = 6;
        constexpr int bits_per_word = 32;

        element_type words[3] = {0u, 0u, 0u};
        static_assert(sizeof(words) == sizeof(type), "unexpected storage size");

        for(int i = 0; i < 16; ++i)
        {
            const element_type bits = static_cast<element_type>(x[i]) & 0x3Fu;
            const int first_bit     = i * bits_per_elem;
            const int word          = first_bit / bits_per_word;
            const int shift         = first_bit % bits_per_word;

            words[word] |= bits << shift;
            if(shift + bits_per_elem > bits_per_word)
            {
                // Upper part of the field spills into the next word.
                words[word + 1] |= bits >> (bits_per_word - shift);
            }
        }

        type packed{};
        __builtin_memcpy(&packed, words, sizeof(words));
        return packed;
    }
};
// Pack of 32 bf6_t values (6 bits each, 192 bits total) stored in 6 uint32_t words.
//
// Bit layout used by pack()/unpack(): element i occupies bits [6*i, 6*i + 6) of the
// 192-bit stream formed by the words in order, word 0 holding the lowest bits. An
// element may straddle two neighbouring words.
// NOTE(review): confirm this LSB-first layout matches every consumer of the raw
// `data` words, and that bf6_t converts to/from integers via static_cast.
struct bf6x32_pk_t
{
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 6>;
    type data;

    // Constructors are host/device like the rest of the interface so the pack can be
    // created inside __host__ __device__ code.
    __host__ __device__ bf6x32_pk_t() : data{type{}} {}
    __host__ __device__ bf6x32_pk_t(type init) : data{init} {}

    // Returns element I (0 <= I < 32) of the pack.
    //
    // The 6-bit field is extracted with shifts and masks. Overlaying the storage with
    // a bf6_t[32] array cannot work: sizeof(bf6_t) >= 1, so such an array is at least
    // 32 bytes while the storage is only 24 bytes.
    template <index_t I>
    __host__ __device__ inline bf6_t unpack() const
    {
        static_assert(I >= 0 && I < 32, "Index is out of range.");

        constexpr int bits_per_elem = 6;
        constexpr int bits_per_word = 32;
        constexpr int first_bit     = static_cast<int>(I) * bits_per_elem;
        constexpr int word          = first_bit / bits_per_word;
        constexpr int shift         = first_bit % bits_per_word;

        // Copy the storage into a plain array so words can be indexed directly.
        element_type words[6];
        static_assert(sizeof(words) == sizeof(type), "unexpected storage size");
        __builtin_memcpy(words, &data, sizeof(words));

        element_type bits = words[word] >> shift;
        if constexpr(shift + bits_per_elem > bits_per_word)
        {
            // The field continues in the low bits of the next word.
            bits |= words[word + 1] << (bits_per_word - shift);
        }
        return static_cast<bf6_t>(bits & 0x3Fu);
    }

    // Packs the 32 values x[0..31] into the 6-word storage format and returns it.
    // x must point to at least 32 readable bf6_t values; only the low 6 bits of each
    // value are kept.
    __host__ __device__ inline type pack(bf6_t* x) const
    {
        constexpr int bits_per_elem = 6;
        constexpr int bits_per_word = 32;

        element_type words[6] = {0u, 0u, 0u, 0u, 0u, 0u};
        static_assert(sizeof(words) == sizeof(type), "unexpected storage size");

        for(int i = 0; i < 32; ++i)
        {
            const element_type bits = static_cast<element_type>(x[i]) & 0x3Fu;
            const int first_bit     = i * bits_per_elem;
            const int word          = first_bit / bits_per_word;
            const int shift         = first_bit % bits_per_word;

            words[word] |= bits << shift;
            if(shift + bits_per_elem > bits_per_word)
            {
                // Upper part of the field spills into the next word.
                words[word + 1] |= bits >> (bits_per_word - shift);
            }
        }

        type packed{};
        __builtin_memcpy(&packed, words, sizeof(words));
        return packed;
    }
};
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
{
// Precondition: x > 1.
...
...
@@ -45,7 +166,7 @@ inline constexpr auto next_pow2(uint32_t x)
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool
// native types: bool
, f4_t, f6_t, bf6_t
template
<
typename
T
>
inline
constexpr
bool
is_native_type
()
{
...
...
@@ -1065,12 +1186,37 @@ struct nnvb_data_t_selector<f8_ocp_t>
{
using
type
=
f8_ocp_t
::
data_type
;
};
// Storage-type selector for bf8_ocp_t: exposes bf8_ocp_t::data_type as `type`.
template <>
struct nnvb_data_t_selector<bf8_ocp_t>
{
    using type = bf8_ocp_t::data_type;
};
// Storage-type selector for f6x16_pk_t: exposes the pack's own word array
// (f6x16_pk_t::type) as `type`.
template <>
struct nnvb_data_t_selector<f6x16_pk_t>
{
    using type = f6x16_pk_t::type;
};
// Storage-type selector for f6x32_pk_t: exposes the pack's own word array
// (f6x32_pk_t::type) as `type`.
template <>
struct nnvb_data_t_selector<f6x32_pk_t>
{
    using type = f6x32_pk_t::type;
};
// Storage-type selector for bf6x16_pk_t: exposes the pack's own word array
// (bf6x16_pk_t::type) as `type`.
template <>
struct nnvb_data_t_selector<bf6x16_pk_t>
{
    using type = bf6x16_pk_t::type;
};
// Storage-type selector for bf6x32_pk_t: exposes the pack's own word array
// (bf6x32_pk_t::type) as `type`.
template <>
struct nnvb_data_t_selector<bf6x32_pk_t>
{
    using type = bf6x32_pk_t::type;
};
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
<
T
,
...
...
@@ -1171,6 +1317,111 @@ struct non_native_vector_base<
}
};
// Implementation for the 6-bit packs (f6x16 and f6x32).
//
// This partial specialization is selected for any T whose size is 12 or 24 bytes.
// It stores N values of T in a union that can be viewed as one ext_vector_type
// vector of words (data_v), as N storage arrays (data_t) or as N packs (T).
template <typename T, index_t N>
struct non_native_vector_base<T, N, std::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
{
    // Storage type chosen by nnvb_data_t_selector for the declared base type T.
    using data_t = typename nnvb_data_t_selector<T>::type;
    // Word type declared by T itself.
    using element_t = typename T::element_type;
    static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
    // Number of element_t words per T. f6x16: 12/4 = 3, f6x32: 24/4 = 6.
    static constexpr size_t size_factor = sizeof(data_t) / sizeof(element_t);
    // Vector holding all N * size_factor words.
    using data_v = element_t __attribute__((ext_vector_type(N * size_factor)));
    using type = non_native_vector_base<T, N>;

    // All members overlay the same storage. dN is declared first, so it is the
    // member set by the brace initializers in the constructors below.
    union alignas(next_pow2(N * sizeof(T)))
    {
        data_v dN; // storage vector
        StaticallyIndexedArray<data_t, N> dxN;
        StaticallyIndexedArray<T, N> dTxN;
        StaticallyIndexedArray<data_v, 1> dNx1;
    } data_;

    // Builds the vector from a storage array.
    // NOTE(review): only word 0 of `a` is read; it is converted to data_v, which
    // looks like it replicates that single word and drops the remaining words of
    // `a`. Confirm this is intended.
    __host__ __device__ constexpr non_native_vector_base(data_t a)
        : data_{data_v(a.At(Number<0>{}))}
    {
    }
    // Builds the vector from a pack by bit_cast-ing it to data_t and delegating to
    // the data_t constructor above.
    __host__ __device__ constexpr non_native_vector_base(T f)
        : non_native_vector_base(bit_cast<data_t>(f))
    {
    }
    // Default construction delegates to the T constructor with a value-initialized T.
    __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
    // Builds the vector directly from a word vector.
    __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}

    // Returns the word-vector view of the storage.
    __host__ __device__ constexpr operator data_v() const { return data_.dN; }

    // Returns the single storage array when N == 1.
    __host__ __device__ constexpr operator data_t() const
    {
        if constexpr(N == 1)
        {
            return data_.dxN[Number<0>{}];
        }
        else
        {
            // Returns the whole array where a single data_t is expected.
            return data_.dxN; // XXX this should cause an error
        }
    }

    // Returns the single pack when N == 1.
    __host__ __device__ constexpr operator T() const
    {
        if constexpr(N == 1)
        {
            return data_.dTxN[Number<0>{}];
        }
        else
        {
            // Returns the whole array where a single T is expected.
            return data_.dTxN; // XXX this should cause an error
        }
    }

    // Read-only access to the union member matching X:
    // data_t -> dxN, T -> dTxN, data_v -> dNx1. Any other X fails the static_assert.
    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
                      "Something went wrong, please check src and dst types.");

        if constexpr(is_same_v<X, data_t>)
        {
            return data_.dxN;
        }
        else if constexpr(is_same_v<X, T>)
        {
            return data_.dTxN;
        }
        else if constexpr(is_same_v<X, data_v>)
        {
            return data_.dNx1;
        }
        else
        {
            // Not taken for any X accepted by the static_assert above.
            // `err` is not declared in this block (presumably defined elsewhere in the file).
            return err;
        }
    }

    // Mutable access to the union member matching X:
    // data_t -> dxN, T -> dTxN, data_v -> dNx1. Any other X fails the static_assert.
    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
                      "Something went wrong, please check src and dst types.");

        if constexpr(is_same_v<X, data_t>)
        {
            return data_.dxN;
        }
        else if constexpr(is_same_v<X, T>)
        {
            return data_.dTxN;
        }
        else if constexpr(is_same_v<X, data_v>)
        {
            return data_.dNx1;
        }
        else
        {
            // Not taken for any X accepted by the static_assert above.
            // `err` is not declared in this block (presumably defined elsewhere in the file).
            return err;
        }
    }
};
// Declares, without defining it here, the scalar_type partial specialization for
// non_native_vector_base<T, N>.
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>;
...
...
@@ -1906,6 +2157,10 @@ using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type;
// Vector of 16 f4x2_pk_t packs.
using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type;
// Vector of 32 f4x2_pk_t packs.
using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type;
// f6: each alias is a vector holding exactly one pack.
// Vector of one f6x16_pk_t pack.
using f6x16_t = typename vector_type<f6x16_pk_t, 1>::type;
// Vector of one f6x32_pk_t pack.
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
template
<
typename
T
>
struct
NumericLimits
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment