Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bcc12098
Commit
bcc12098
authored
Jan 30, 2025
by
Rostyslav Geyyer
Browse files
Update pack/unpack methods
parent
acf8854e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
180 additions
and
36 deletions
+180
-36
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+180
-36
No files found.
include/ck/utility/data_type.hpp
View file @
bcc12098
...
...
@@ -51,15 +51,24 @@ struct f6x16_pk_t
template
<
index_t
I
>
__host__
__device__
inline
f6_t
unpack
(
Number
<
I
>
)
{
static_assert
(
I
<
16
,
"Index
is
out of range."
);
static_assert
(
I
<
16
,
"Index out of range
for 16 f6_t elements
."
);
union
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
3
;
constexpr
int
bit_pos
=
I
*
num_bits_elem
;
constexpr
int
arr_idx
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
uint32_t
bits
=
data
.
At
(
Number
<
arr_idx
>
{})
>>
bit_offset
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
if
constexpr
(
overhang
>
0
&&
(
arr_idx
+
1
)
<
vector_size
)
{
StaticallyIndexedArray_v2
<
element_type
,
3
>
uint32_array
;
f6_t
f6_array
[
16
]
;
}
data_union
{
data
};
bits
|=
(
data
.
At
(
Number
<
arr_idx
+
1
>
{})
&
((
1u
<<
overhang
)
-
1
))
<<
(
num_bits_elem
-
overhang
)
;
}
return
data_union
.
f6_array
[
I
]
;
return
static_cast
<
f6_t
>
(
bits
&
0x3F
)
;
}
// V is vector_size and E is number of elements
...
...
@@ -68,14 +77,41 @@ struct f6x16_pk_t
{
static_assert
(
V
==
1
||
V
==
2
,
"Vector size must be 1 or 2."
);
static_assert
(
E
==
16
,
"Number of elements must be 16."
);
type
*
retval
=
reinterpret_cast
<
type
*>
(
x
);
return
*
retval
;
type
packed
{};
// for each of the 16 f6_t values, place its 6 bits in the correct position
ck
::
static_for
<
0
,
16
,
1
>
{}([
&
](
auto
i
)
{
uint32_t
bits
=
static_cast
<
uint32_t
>
(
x
[
i
])
&
0x3F
;
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
3
;
constexpr
int
bit_pos
=
i
*
num_bits_elem
;
constexpr
int
arr_index
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
uint32_t
old_value
=
packed
.
At
(
Number
<
arr_index
>
{});
// insert bits into the current 32-bit block
old_value
|=
(
bits
<<
bit_offset
);
packed
.
At
(
Number
<
arr_index
>
{})
=
old_value
;
// if it crosses into the next block, shift the remainder
if
constexpr
(
overhang
>
0
&&
(
arr_index
+
1
)
<
vector_size
)
{
uint32_t
next_value
=
packed
.
At
(
Number
<
arr_index
+
1
>
{});
next_value
|=
(
bits
>>
(
num_bits_elem
-
overhang
));
packed
.
At
(
Number
<
arr_index
+
1
>
{})
=
next_value
;
}
});
return
packed
;
}
};
struct
f6x32_pk_t
{
// store
16
elements of f6_t in an array of 6 uint32_t
// store
32
elements of f6_t in an array of 6 uint32_t
using
element_type
=
uint32_t
;
using
type
=
StaticallyIndexedArray_v2
<
element_type
,
6
>
;
type
data
;
...
...
@@ -85,15 +121,24 @@ struct f6x32_pk_t
template
<
index_t
I
>
__host__
__device__
inline
f6_t
unpack
(
Number
<
I
>
)
{
static_assert
(
I
<
32
,
"Index
is
out of range."
);
static_assert
(
I
<
32
,
"Index out of range
for 32 f6_t elements
."
);
union
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
6
;
constexpr
int
bit_pos
=
I
*
num_bits_elem
;
constexpr
int
arr_idx
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
uint32_t
bits
=
data
.
At
(
Number
<
arr_idx
>
{})
>>
bit_offset
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
if
constexpr
(
overhang
>
0
&&
(
arr_idx
+
1
)
<
vector_size
)
{
StaticallyIndexedArray_v2
<
element_type
,
6
>
uint32_array
;
f6_t
f6_array
[
32
]
;
}
data_union
{
data
};
bits
|=
(
data
.
At
(
Number
<
arr_idx
+
1
>
{})
&
((
1u
<<
overhang
)
-
1
))
<<
(
num_bits_elem
-
overhang
)
;
}
return
data_union
.
f6_array
[
I
]
;
return
static_cast
<
f6_t
>
(
bits
&
0x3F
)
;
}
// V is vector_size and E is number of elements
...
...
@@ -101,9 +146,36 @@ struct f6x32_pk_t
__host__
__device__
inline
type
pack
(
f6_t
*
x
)
{
static_assert
(
V
==
1
,
"Vector size must be 1."
);
static_assert
(
E
==
32
,
"Number of elements must be 16."
);
type
*
retval
=
reinterpret_cast
<
type
*>
(
x
);
return
*
retval
;
static_assert
(
E
==
32
,
"Number of elements must be 32."
);
type
packed
{};
// for each of the 32 f6_t values, place its 6 bits in the correct position
ck
::
static_for
<
0
,
32
,
1
>
{}([
&
](
auto
i
)
{
uint32_t
bits
=
static_cast
<
uint32_t
>
(
x
[
i
])
&
0x3F
;
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
6
;
constexpr
int
bit_pos
=
i
*
num_bits_elem
;
constexpr
int
arr_index
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
uint32_t
old_value
=
packed
.
At
(
Number
<
arr_index
>
{});
// insert bits into the current 32-bit block
old_value
|=
(
bits
<<
bit_offset
);
packed
.
At
(
Number
<
arr_index
>
{})
=
old_value
;
// if it crosses into the next block, shift the remainder
if
constexpr
(
overhang
>
0
&&
(
arr_index
+
1
)
<
vector_size
)
{
uint32_t
next_value
=
packed
.
At
(
Number
<
arr_index
+
1
>
{});
next_value
|=
(
bits
>>
(
num_bits_elem
-
overhang
));
packed
.
At
(
Number
<
arr_index
+
1
>
{})
=
next_value
;
}
});
return
packed
;
}
};
...
...
@@ -119,15 +191,24 @@ struct bf6x16_pk_t
template
<
index_t
I
>
__host__
__device__
inline
bf6_t
unpack
(
Number
<
I
>
)
{
static_assert
(
I
<
16
,
"Index
is
out of range."
);
static_assert
(
I
<
16
,
"Index out of range
for 16 f6_t elements
."
);
union
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
3
;
constexpr
int
bit_pos
=
I
*
num_bits_elem
;
constexpr
int
arr_idx
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
uint32_t
bits
=
data
.
At
(
Number
<
arr_idx
>
{})
>>
bit_offset
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
if
constexpr
(
overhang
>
0
&&
(
arr_idx
+
1
)
<
vector_size
)
{
StaticallyIndexedArray_v2
<
element_type
,
3
>
uint32_array
;
bf6_t
bf6_array
[
16
]
;
}
data_union
{
data
};
bits
|=
(
data
.
At
(
Number
<
arr_idx
+
1
>
{})
&
((
1u
<<
overhang
)
-
1
))
<<
(
num_bits_elem
-
overhang
)
;
}
return
data_union
.
bf6_array
[
I
]
;
return
static_cast
<
bf6_t
>
(
bits
&
0x3F
)
;
}
// V is vector_size and E is number of elements
...
...
@@ -136,14 +217,41 @@ struct bf6x16_pk_t
{
static_assert
(
V
==
1
||
V
==
2
,
"Vector size must be 1 or 2."
);
static_assert
(
E
==
16
,
"Number of elements must be 16."
);
type
*
retval
=
reinterpret_cast
<
type
*>
(
x
);
return
*
retval
;
type
packed
{};
// for each of the 16 bf6_t values, place its 6 bits in the correct position
ck
::
static_for
<
0
,
16
,
1
>
{}([
&
](
auto
i
)
{
uint32_t
bits
=
static_cast
<
uint32_t
>
(
x
[
i
])
&
0x3F
;
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
3
;
constexpr
int
bit_pos
=
i
*
num_bits_elem
;
constexpr
int
arr_index
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
uint32_t
old_value
=
packed
.
At
(
Number
<
arr_index
>
{});
// insert bits into the current 32-bit block
old_value
|=
(
bits
<<
bit_offset
);
packed
.
At
(
Number
<
arr_index
>
{})
=
old_value
;
// if it crosses into the next block, shift the remainder
if
constexpr
(
overhang
>
0
&&
(
arr_index
+
1
)
<
vector_size
)
{
uint32_t
next_value
=
packed
.
At
(
Number
<
arr_index
+
1
>
{});
next_value
|=
(
bits
>>
(
num_bits_elem
-
overhang
));
packed
.
At
(
Number
<
arr_index
+
1
>
{})
=
next_value
;
}
});
return
packed
;
}
};
struct
bf6x32_pk_t
{
// store
16
elements of bf6_t in an array of 6 uint32_t
// store
32
elements of bf6_t in an array of 6 uint32_t
using
element_type
=
uint32_t
;
using
type
=
StaticallyIndexedArray_v2
<
element_type
,
6
>
;
type
data
;
...
...
@@ -153,15 +261,24 @@ struct bf6x32_pk_t
template
<
index_t
I
>
__host__
__device__
inline
bf6_t
unpack
(
Number
<
I
>
)
{
static_assert
(
I
<
32
,
"Index
is
out of range."
);
static_assert
(
I
<
32
,
"Index out of range
for 32 f6_t elements
."
);
union
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
6
;
constexpr
int
bit_pos
=
I
*
num_bits_elem
;
constexpr
int
arr_idx
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
uint32_t
bits
=
data
.
At
(
Number
<
arr_idx
>
{})
>>
bit_offset
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
if
constexpr
(
overhang
>
0
&&
(
arr_idx
+
1
)
<
vector_size
)
{
StaticallyIndexedArray_v2
<
element_type
,
6
>
uint32_array
;
bf6_t
bf6_array
[
32
]
;
}
data_union
{
data
};
bits
|=
(
data
.
At
(
Number
<
arr_idx
+
1
>
{})
&
((
1u
<<
overhang
)
-
1
))
<<
(
num_bits_elem
-
overhang
)
;
}
return
data_union
.
bf6_array
[
I
]
;
return
static_cast
<
bf6_t
>
(
bits
&
0x3F
)
;
}
// V is vector_size and E is number of elements
...
...
@@ -169,9 +286,36 @@ struct bf6x32_pk_t
__host__
__device__
inline
type
pack
(
bf6_t
*
x
)
{
static_assert
(
V
==
1
,
"Vector size must be 1."
);
static_assert
(
E
==
32
,
"Number of elements must be 16."
);
type
*
retval
=
reinterpret_cast
<
type
*>
(
x
);
return
*
retval
;
static_assert
(
E
==
32
,
"Number of elements must be 32."
);
type
packed
{};
// for each of the 32 bf6_t values, place its 6 bits in the correct position
ck
::
static_for
<
0
,
32
,
1
>
{}([
&
](
auto
i
)
{
uint32_t
bits
=
static_cast
<
uint32_t
>
(
x
[
i
])
&
0x3F
;
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
6
;
constexpr
int
bit_pos
=
i
*
num_bits_elem
;
constexpr
int
arr_index
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
uint32_t
old_value
=
packed
.
At
(
Number
<
arr_index
>
{});
// insert bits into the current 32-bit block
old_value
|=
(
bits
<<
bit_offset
);
packed
.
At
(
Number
<
arr_index
>
{})
=
old_value
;
// if it crosses into the next block, shift the remainder
if
constexpr
(
overhang
>
0
&&
(
arr_index
+
1
)
<
vector_size
)
{
uint32_t
next_value
=
packed
.
At
(
Number
<
arr_index
+
1
>
{});
next_value
|=
(
bits
>>
(
num_bits_elem
-
overhang
));
packed
.
At
(
Number
<
arr_index
+
1
>
{})
=
next_value
;
}
});
return
packed
;
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment