Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
dda07ed0
Commit
dda07ed0
authored
Jun 11, 2024
by
Rostyslav Geyyer
Browse files
Use vector_type to cover non-native implementation as well
parent
fd019e14
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
82 additions
and
85 deletions
+82
-85
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+7
-9
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+16
-16
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+54
-47
include/ck/utility/transpose_vectors.hpp
include/ck/utility/transpose_vectors.hpp
+4
-12
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+1
-1
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
dda07ed0
...
@@ -322,8 +322,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -322,8 +322,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
b_thread_buf
);
b_thread_buf
);
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
non_native_
vector_type
<
ComputeTypeA
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeTypeA
,
KPack
>
a_thread_vec
;
non_native_
vector_type
<
ComputeTypeB
,
KPack
>
b_thread_vec
;
vector_type
<
ComputeTypeB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
ComputeTypeA
>()(
i
)
=
a_thread_buf
a_thread_vec
.
template
AsType
<
ComputeTypeA
>()(
i
)
=
a_thread_buf
...
@@ -333,11 +333,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -333,11 +333,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
});
});
using
mfma_input_type_a
=
using
mfma_input_type_a
=
typename
non_native_vector_type
<
ComputeTypeA
,
typename
vector_type
<
ComputeTypeA
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
using
mfma_input_type_b
=
using
mfma_input_type_b
=
typename
non_native_vector_type
<
ComputeTypeB
,
typename
vector_type
<
ComputeTypeB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
...
@@ -949,8 +947,8 @@ struct BlockwiseGemmXdlops_v2
...
@@ -949,8 +947,8 @@ struct BlockwiseGemmXdlops_v2
b_thread_desc_
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
b_thread_buf
);
non_native_
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
non_native_
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
...
@@ -960,7 +958,7 @@ struct BlockwiseGemmXdlops_v2
...
@@ -960,7 +958,7 @@ struct BlockwiseGemmXdlops_v2
});
});
using
mfma_input_type
=
using
mfma_input_type
=
typename
non_native_
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
...
...
include/ck/utility/amd_xdlops.hpp
View file @
dda07ed0
...
@@ -375,8 +375,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
...
@@ -375,8 +375,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
0
,
0
,
0
);
0
);
#else
#else
non_native_
vector_type
<
f8_t
,
8
>
reg_a_v
(
reg_a
);
vector_type
<
f8_t
,
8
>
reg_a_v
(
reg_a
);
non_native_
vector_type
<
f8_t
,
8
>
reg_b_v
(
reg_b
);
vector_type
<
f8_t
,
8
>
reg_b_v
(
reg_b
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
...
@@ -406,8 +406,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
...
@@ -406,8 +406,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
0
,
0
,
0
);
0
);
#else
#else
non_native_
vector_type
<
f8_t
,
8
>
reg_a_v
(
reg_a
);
vector_type
<
f8_t
,
8
>
reg_a_v
(
reg_a
);
non_native_
vector_type
<
f8_t
,
8
>
reg_b_v
(
reg_b
);
vector_type
<
f8_t
,
8
>
reg_b_v
(
reg_b
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
...
@@ -438,8 +438,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
...
@@ -438,8 +438,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
0
,
0
,
0
);
0
);
#else
#else
non_native_
vector_type
<
bf8_t
,
8
>
reg_a_v
(
reg_a
);
vector_type
<
bf8_t
,
8
>
reg_a_v
(
reg_a
);
non_native_
vector_type
<
bf8_t
,
8
>
reg_b_v
(
reg_b
);
vector_type
<
bf8_t
,
8
>
reg_b_v
(
reg_b
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
bf8_t
>()[
Number
<
k
>
{}]);
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
bf8_t
>()[
Number
<
k
>
{}]);
...
@@ -469,8 +469,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
...
@@ -469,8 +469,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
0
,
0
,
0
);
0
);
#else
#else
non_native_
vector_type
<
bf8_t
,
8
>
reg_a_v
(
reg_a
);
vector_type
<
bf8_t
,
8
>
reg_a_v
(
reg_a
);
non_native_
vector_type
<
bf8_t
,
8
>
reg_b_v
(
reg_b
);
vector_type
<
bf8_t
,
8
>
reg_b_v
(
reg_b
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
bf8_t
>()[
Number
<
k
>
{}]);
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
bf8_t
>()[
Number
<
k
>
{}]);
...
@@ -501,8 +501,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
...
@@ -501,8 +501,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
0
,
0
,
0
);
0
);
#else
#else
non_native_
vector_type
<
f8_t
,
8
>
reg_a_v
(
reg_a
);
vector_type
<
f8_t
,
8
>
reg_a_v
(
reg_a
);
non_native_
vector_type
<
bf8_t
,
8
>
reg_b_v
(
reg_b
);
vector_type
<
bf8_t
,
8
>
reg_b_v
(
reg_b
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
...
@@ -532,8 +532,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
...
@@ -532,8 +532,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
0
,
0
,
0
);
0
);
#else
#else
non_native_
vector_type
<
f8_t
,
8
>
reg_a_v
(
reg_a
);
vector_type
<
f8_t
,
8
>
reg_a_v
(
reg_a
);
non_native_
vector_type
<
bf8_t
,
8
>
reg_b_v
(
reg_b
);
vector_type
<
bf8_t
,
8
>
reg_b_v
(
reg_b
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
...
@@ -564,8 +564,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
...
@@ -564,8 +564,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
0
,
0
,
0
);
0
);
#else
#else
non_native_
vector_type
<
bf8_t
,
8
>
reg_a_v
(
reg_a
);
vector_type
<
bf8_t
,
8
>
reg_a_v
(
reg_a
);
non_native_
vector_type
<
f8_t
,
8
>
reg_b_v
(
reg_b
);
vector_type
<
f8_t
,
8
>
reg_b_v
(
reg_b
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
bf8_t
>()[
Number
<
k
>
{}]);
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
bf8_t
>()[
Number
<
k
>
{}]);
...
@@ -595,8 +595,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
...
@@ -595,8 +595,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
0
,
0
,
0
);
0
);
#else
#else
non_native_
vector_type
<
bf8_t
,
8
>
reg_a_v
(
reg_a
);
vector_type
<
bf8_t
,
8
>
reg_a_v
(
reg_a
);
non_native_
vector_type
<
f8_t
,
8
>
reg_b_v
(
reg_b
);
vector_type
<
f8_t
,
8
>
reg_b_v
(
reg_b
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
bf8_t
>()[
Number
<
k
>
{}]);
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
bf8_t
>()[
Number
<
k
>
{}]);
...
...
include/ck/utility/data_type.hpp
View file @
dda07ed0
...
@@ -19,8 +19,18 @@ inline constexpr auto next_pow2(uint32_t x)
...
@@ -19,8 +19,18 @@ inline constexpr auto next_pow2(uint32_t x)
return
x
>
1u
?
(
1u
<<
(
32u
-
__builtin_clz
(
x
-
1u
)))
:
x
;
return
x
>
1u
?
(
1u
<<
(
32u
-
__builtin_clz
(
x
-
1u
)))
:
x
;
}
}
// Compile-time predicate: true when T is one of the types listed below, false
// otherwise. The vector_type specializations use it (via std::enable_if_t) to
// choose between the native and the non-native implementation.
//
// Types accepted: double, float, half_t, bhalf_t, int32_t, int8_t, uint8_t,
// _BitInt(8), unsigned _BitInt(8), bool.
template <typename T>
inline constexpr bool is_native_type()
{
    // floating-point element types
    constexpr bool matches_floating_point = is_same<T, double>::value ||
                                            is_same<T, float>::value ||
                                            is_same<T, half_t>::value ||
                                            is_same<T, bhalf_t>::value;

    // fixed-width integer element types
    constexpr bool matches_integer = is_same<T, int32_t>::value ||
                                     is_same<T, int8_t>::value ||
                                     is_same<T, uint8_t>::value;

    // 8-bit _BitInt element types, signed and unsigned
    constexpr bool matches_bit_int =
        is_same<T, _BitInt(8)>::value || is_same<T, unsigned _BitInt(8)>::value;

    constexpr bool matches_bool = is_same<T, bool>::value;

    return matches_floating_point || matches_integer || matches_bit_int || matches_bool;
}
// vector_type
// vector_type
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
,
typename
Enable
=
void
>
struct
vector_type
;
struct
vector_type
;
// Caution: DO NOT REMOVE
// Caution: DO NOT REMOVE
...
@@ -177,7 +187,7 @@ struct scalar_type<bool>
...
@@ -177,7 +187,7 @@ struct scalar_type<bool>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
using
type
=
d1_t
;
using
type
=
d1_t
;
...
@@ -211,7 +221,7 @@ struct vector_type<T, 1>
...
@@ -211,7 +221,7 @@ struct vector_type<T, 1>
int
static
err
=
0
;
int
static
err
=
0
;
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
2
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -269,7 +279,7 @@ struct vector_type<T, 2>
...
@@ -269,7 +279,7 @@ struct vector_type<T, 2>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
4
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -339,7 +349,7 @@ struct vector_type<T, 4>
...
@@ -339,7 +349,7 @@ struct vector_type<T, 4>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
8
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -421,7 +431,7 @@ struct vector_type<T, 8>
...
@@ -421,7 +431,7 @@ struct vector_type<T, 8>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
16
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -515,7 +525,7 @@ struct vector_type<T, 16>
...
@@ -515,7 +525,7 @@ struct vector_type<T, 16>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
32
>
struct
vector_type
<
T
,
32
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -619,7 +629,7 @@ struct vector_type<T, 32>
...
@@ -619,7 +629,7 @@ struct vector_type<T, 32>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
64
>
struct
vector_type
<
T
,
64
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -735,7 +745,7 @@ struct vector_type<T, 64>
...
@@ -735,7 +745,7 @@ struct vector_type<T, 64>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
128
>
struct
vector_type
<
T
,
128
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -861,7 +871,7 @@ struct vector_type<T, 128>
...
@@ -861,7 +871,7 @@ struct vector_type<T, 128>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
256
>
struct
vector_type
<
T
,
256
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
{
using
d1_t
=
T
;
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -1013,12 +1023,9 @@ struct non_native_vector_base
...
@@ -1013,12 +1023,9 @@ struct non_native_vector_base
T
d
[
N
];
T
d
[
N
];
};
};
// non-native vector_type
// non-native vector_type implementation
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_type
;
template
<
typename
T
>
template
<
typename
T
>
struct
non_native_
vector_type
<
T
,
1
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>
>
{
{
using
Native_vec_
=
non_native_vector_base
<
T
,
1
>
;
using
Native_vec_
=
non_native_vector_base
<
T
,
1
>
;
...
@@ -1031,9 +1038,9 @@ struct non_native_vector_type<T, 1>
...
@@ -1031,9 +1038,9 @@ struct non_native_vector_type<T, 1>
StaticallyIndexedArray
<
d1_t
,
1
>
d1x1_
;
StaticallyIndexedArray
<
d1_t
,
1
>
d1x1_
;
}
data_
;
}
data_
;
__host__
__device__
constexpr
non_native_
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
non_native_
vector_type
(
type
v
)
:
data_
{
v
}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
...
@@ -1053,7 +1060,7 @@ struct non_native_vector_type<T, 1>
...
@@ -1053,7 +1060,7 @@ struct non_native_vector_type<T, 1>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
non_native_
vector_type
<
T
,
2
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>
>
{
{
using
Native_vec_
=
non_native_vector_base
<
T
,
2
>
;
using
Native_vec_
=
non_native_vector_base
<
T
,
2
>
;
...
@@ -1069,9 +1076,9 @@ struct non_native_vector_type<T, 2>
...
@@ -1069,9 +1076,9 @@ struct non_native_vector_type<T, 2>
StaticallyIndexedArray
<
d2_t
,
1
>
d2x1_
;
StaticallyIndexedArray
<
d2_t
,
1
>
d2x1_
;
}
data_
;
}
data_
;
__host__
__device__
constexpr
non_native_
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
non_native_
vector_type
(
type
v
)
:
data_
{
v
}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
...
@@ -1113,7 +1120,7 @@ struct non_native_vector_type<T, 2>
...
@@ -1113,7 +1120,7 @@ struct non_native_vector_type<T, 2>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
non_native_
vector_type
<
T
,
4
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>
>
{
{
using
Native_vec_
=
non_native_vector_base
<
T
,
4
>
;
using
Native_vec_
=
non_native_vector_base
<
T
,
4
>
;
...
@@ -1131,9 +1138,9 @@ struct non_native_vector_type<T, 4>
...
@@ -1131,9 +1138,9 @@ struct non_native_vector_type<T, 4>
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
}
data_
;
}
data_
;
__host__
__device__
constexpr
non_native_
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
non_native_
vector_type
(
type
v
)
:
data_
{
v
}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
...
@@ -1185,7 +1192,7 @@ struct non_native_vector_type<T, 4>
...
@@ -1185,7 +1192,7 @@ struct non_native_vector_type<T, 4>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
non_native_
vector_type
<
T
,
8
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>
>
{
{
using
Native_vec_
=
non_native_vector_base
<
T
,
8
>
;
using
Native_vec_
=
non_native_vector_base
<
T
,
8
>
;
...
@@ -1205,9 +1212,9 @@ struct non_native_vector_type<T, 8>
...
@@ -1205,9 +1212,9 @@ struct non_native_vector_type<T, 8>
StaticallyIndexedArray
<
d8_t
,
1
>
d8x1_
;
StaticallyIndexedArray
<
d8_t
,
1
>
d8x1_
;
}
data_
;
}
data_
;
__host__
__device__
constexpr
non_native_
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
non_native_
vector_type
(
type
v
)
:
data_
{
v
}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
...
@@ -1269,7 +1276,7 @@ struct non_native_vector_type<T, 8>
...
@@ -1269,7 +1276,7 @@ struct non_native_vector_type<T, 8>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
non_native_
vector_type
<
T
,
16
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>
>
{
{
using
Native_vec_
=
non_native_vector_base
<
T
,
16
>
;
using
Native_vec_
=
non_native_vector_base
<
T
,
16
>
;
...
@@ -1291,9 +1298,9 @@ struct non_native_vector_type<T, 16>
...
@@ -1291,9 +1298,9 @@ struct non_native_vector_type<T, 16>
StaticallyIndexedArray
<
d16_t
,
1
>
d16x1_
;
StaticallyIndexedArray
<
d16_t
,
1
>
d16x1_
;
}
data_
;
}
data_
;
__host__
__device__
constexpr
non_native_
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
non_native_
vector_type
(
type
v
)
:
data_
{
v
}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
...
@@ -1365,7 +1372,7 @@ struct non_native_vector_type<T, 16>
...
@@ -1365,7 +1372,7 @@ struct non_native_vector_type<T, 16>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
non_native_
vector_type
<
T
,
32
>
struct
vector_type
<
T
,
32
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>
>
{
{
using
Native_vec_
=
non_native_vector_base
<
T
,
32
>
;
using
Native_vec_
=
non_native_vector_base
<
T
,
32
>
;
...
@@ -1389,9 +1396,9 @@ struct non_native_vector_type<T, 32>
...
@@ -1389,9 +1396,9 @@ struct non_native_vector_type<T, 32>
StaticallyIndexedArray
<
d32_t
,
1
>
d32x1_
;
StaticallyIndexedArray
<
d32_t
,
1
>
d32x1_
;
}
data_
;
}
data_
;
__host__
__device__
constexpr
non_native_
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
non_native_
vector_type
(
type
v
)
:
data_
{
v
}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
...
@@ -1471,7 +1478,7 @@ struct non_native_vector_type<T, 32>
...
@@ -1471,7 +1478,7 @@ struct non_native_vector_type<T, 32>
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
non_native_
vector_type
<
T
,
64
>
struct
vector_type
<
T
,
64
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>
>
{
{
using
Native_vec_
=
non_native_vector_base
<
T
,
64
>
;
using
Native_vec_
=
non_native_vector_base
<
T
,
64
>
;
...
@@ -1497,9 +1504,9 @@ struct non_native_vector_type<T, 64>
...
@@ -1497,9 +1504,9 @@ struct non_native_vector_type<T, 64>
StaticallyIndexedArray
<
d64_t
,
1
>
d64x1_
;
StaticallyIndexedArray
<
d64_t
,
1
>
d64x1_
;
}
data_
;
}
data_
;
__host__
__device__
constexpr
non_native_
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
non_native_
vector_type
(
type
v
)
:
data_
{
v
}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
__host__
__device__
constexpr
const
auto
&
AsType
()
const
...
@@ -1641,12 +1648,12 @@ using int8x64_t = typename vector_type<int8_t, 64>::type;
...
@@ -1641,12 +1648,12 @@ using int8x64_t = typename vector_type<int8_t, 64>::type;
// using f8x16_t = typename vector_type<f8_t, 16>::type;
// using f8x16_t = typename vector_type<f8_t, 16>::type;
// using f8x32_t = typename vector_type<f8_t, 32>::type;
// using f8x32_t = typename vector_type<f8_t, 32>::type;
// using f8x64_t = typename vector_type<f8_t, 64>::type;
// using f8x64_t = typename vector_type<f8_t, 64>::type;
using
f8x2_t
=
typename
non_native_
vector_type
<
f8_t
,
2
>::
type
;
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
non_native_
vector_type
<
f8_t
,
4
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x8_t
=
typename
non_native_
vector_type
<
f8_t
,
8
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x16_t
=
typename
non_native_
vector_type
<
f8_t
,
16
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x32_t
=
typename
non_native_
vector_type
<
f8_t
,
32
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
non_native_
vector_type
<
f8_t
,
64
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
// bf8
// bf8
// using bf8x2_t = typename vector_type<bf8_t, 2>::type;
// using bf8x2_t = typename vector_type<bf8_t, 2>::type;
...
@@ -1655,12 +1662,12 @@ using f8x64_t = typename non_native_vector_type<f8_t, 64>::type;
...
@@ -1655,12 +1662,12 @@ using f8x64_t = typename non_native_vector_type<f8_t, 64>::type;
// using bf8x16_t = typename vector_type<bf8_t, 16>::type;
// using bf8x16_t = typename vector_type<bf8_t, 16>::type;
// using bf8x32_t = typename vector_type<bf8_t, 32>::type;
// using bf8x32_t = typename vector_type<bf8_t, 32>::type;
// using bf8x64_t = typename vector_type<bf8_t, 64>::type;
// using bf8x64_t = typename vector_type<bf8_t, 64>::type;
using
bf8x2_t
=
typename
non_native_
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x4_t
=
typename
non_native_
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x8_t
=
typename
non_native_
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x16_t
=
typename
non_native_
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
non_native_
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
non_native_
vector_type
<
bf8_t
,
64
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
// u8
// u8
// i8
// i8
...
...
include/ck/utility/transpose_vectors.hpp
View file @
dda07ed0
...
@@ -192,18 +192,10 @@ __device__ void transpose_f8_4x4(const f8x4_t& x0,
...
@@ -192,18 +192,10 @@ __device__ void transpose_f8_4x4(const f8x4_t& x0,
z2
=
__builtin_amdgcn_perm
(
bit_cast
<
int32_t
>
(
t1
),
bit_cast
<
int32_t
>
(
t0
),
m1
);
z2
=
__builtin_amdgcn_perm
(
bit_cast
<
int32_t
>
(
t1
),
bit_cast
<
int32_t
>
(
t0
),
m1
);
z3
=
__builtin_amdgcn_perm
(
bit_cast
<
int32_t
>
(
t1
),
bit_cast
<
int32_t
>
(
t0
),
m2
);
z3
=
__builtin_amdgcn_perm
(
bit_cast
<
int32_t
>
(
t1
),
bit_cast
<
int32_t
>
(
t0
),
m2
);
// y0 = bit_cast<f8x4_t>(z0);
y0
=
bit_cast
<
f8x4_t
>
(
z0
);
// y1 = bit_cast<f8x4_t>(z1);
y1
=
bit_cast
<
f8x4_t
>
(
z1
);
// y2 = bit_cast<f8x4_t>(z2);
y2
=
bit_cast
<
f8x4_t
>
(
z2
);
// y3 = bit_cast<f8x4_t>(z3);
y3
=
bit_cast
<
f8x4_t
>
(
z3
);
std
::
ignore
=
z0
;
std
::
ignore
=
z1
;
std
::
ignore
=
z2
;
std
::
ignore
=
z3
;
std
::
ignore
=
y0
;
std
::
ignore
=
y1
;
std
::
ignore
=
y2
;
std
::
ignore
=
y3
;
}
}
template
<
index_t
NX
,
index_t
NY
>
template
<
index_t
NX
,
index_t
NY
>
...
...
include/ck/utility/type_convert.hpp
View file @
dda07ed0
...
@@ -403,7 +403,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
...
@@ -403,7 +403,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
return
__builtin_amdgcn_cvt_pk_f32_fp8
(
i16val
,
0
);
return
__builtin_amdgcn_cvt_pk_f32_fp8
(
i16val
,
0
);
#else
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
const
auto
f8x2_v
=
non_native_
vector_type
<
f8_t
,
2
>
(
x
);
const
auto
f8x2_v
=
vector_type
<
f8_t
,
2
>
(
x
);
vector_type
<
float
,
2
>
f32x2_v
;
vector_type
<
float
,
2
>
f32x2_v
;
f32x2_v
.
template
AsType
<
float
>()(
Number
<
0
>
{})
=
f32x2_v
.
template
AsType
<
float
>()(
Number
<
0
>
{})
=
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment