Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5e28c17a
Commit
5e28c17a
authored
Sep 13, 2024
by
Rostyslav Geyyer
Browse files
Add mfma selection
parent
29eaa2dc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
20 deletions
+42
-20
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+7
-1
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+35
-19
No files found.
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
5e28c17a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -716,6 +716,12 @@ struct MfmaSelector
return
MfmaInstr
::
mfma_f32_32x32x8f16
;
}
// Explicit specialization of GetMfma for custom_half_t with the integer
// arguments <32, 32>: selects MfmaInstr::mfma_f32_32x32x8f16.
template <>
static constexpr auto GetMfma<custom_half_t, 32, 32>()
{
    constexpr auto selected_instr = MfmaInstr::mfma_f32_32x32x8f16;
    return selected_instr;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
16
,
16
>
()
{
...
...
include/ck/utility/data_type.hpp
View file @
5e28c17a
...
...
@@ -17,8 +17,10 @@ struct custom_half_t
{
using
type
=
short
;
type
data
;
// Default constructor: `data` is value-initialized (zero for the integral `type`).
custom_half_t() : data{} {}
// Constructs from a value of the underlying storage `type`; stored unchanged.
custom_half_t(type init) : data(init) {}
// Default constructor, usable on host and device and in constant expressions:
// `data` is value-initialized (zero for the integral `type`).
__host__ __device__ constexpr custom_half_t() : data{} {}
// Constructs from a value of the underlying storage `type`; stored unchanged.
__host__ __device__ constexpr custom_half_t(type init) : data(init) {}
// Constructs from an int; the value is explicitly converted to the narrower
// storage `type` with static_cast before being stored.
__host__ __device__ constexpr custom_half_t(int init) : data(static_cast<type>(init)) {}
// Constructs from a float; static_cast to the integral storage `type` converts
// the numeric value (discarding the fractional part) rather than encoding the
// float's bit pattern.
// NOTE(review): confirm a value conversion (not an fp16 encoding) is intended here.
__host__ __device__ constexpr custom_half_t(float init) : data(static_cast<type>(init)) {}
};
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
...
...
@@ -37,6 +39,22 @@ inline constexpr bool is_native_type()
is_same
<
T
,
bool
>::
value
;
}
// Plain storage for N elements of type T, held in the public array member `d`.
// NOTE(review): presumably the storage type behind vector_type for element
// types that are not native — confirm against the vector_type specializations.
template <typename T, index_t N>
struct non_native_vector_base
{
    using VecT = non_native_vector_base<T, N>;

    __host__ __device__ non_native_vector_base()            = default;
    __host__ __device__ non_native_vector_base(const VecT&) = default;
    __host__ __device__ non_native_vector_base(VecT&&)      = default;

    // Declaring the move constructor makes the implicitly-declared copy
    // assignment operator deleted, so both assignment operators are defaulted
    // explicitly (Rule of Five) to keep the type assignable.
    __host__ __device__ VecT& operator=(const VecT&) = default;
    __host__ __device__ VecT& operator=(VecT&&)      = default;

    __host__ __device__ ~non_native_vector_base() = default;

    // Element storage.
    T d[N];
};
// vector_type: primary template, declared only. The trailing `Enable`
// parameter defaults to void so that specializations can be selected with
// std::enable_if_t.
template <typename T, index_t N, typename Enable = void>
struct vector_type;
...
...
@@ -114,7 +132,13 @@ struct scalar_type<vector_type<T, N>>
static
constexpr
index_t
vector_size
=
N
;
};
//
// scalar_type specialization for non_native_vector_base<T, N>:
// `type` is the element type T and `vector_size` is the element count N.
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>
{
    using type                         = T;
    static constexpr index_t vector_size = N;
};
template
<
>
struct
scalar_type
<
double
>
{
...
...
@@ -1021,22 +1045,6 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
}
};
// Plain storage for N elements of type T, held in the public array member `d`.
// NOTE(review): presumably the storage type behind vector_type for element
// types that are not native — confirm against the vector_type specializations.
template <typename T, index_t N>
struct non_native_vector_base
{
    using VecT = non_native_vector_base<T, N>;

    __host__ __device__ non_native_vector_base()            = default;
    __host__ __device__ non_native_vector_base(const VecT&) = default;
    __host__ __device__ non_native_vector_base(VecT&&)      = default;

    // Declaring the move constructor makes the implicitly-declared copy
    // assignment operator deleted, so both assignment operators are defaulted
    // explicitly (Rule of Five) to keep the type assignable.
    __host__ __device__ VecT& operator=(const VecT&) = default;
    __host__ __device__ VecT& operator=(VecT&&)      = default;

    __host__ __device__ ~non_native_vector_base() = default;

    // Element storage.
    T d[N];
};
// non-native vector_type implementation
template
<
typename
T
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
...
...
@@ -1631,6 +1639,14 @@ using half16_t = typename vector_type<half_t, 16>::type;
// half_t vector aliases with 32 and 64 elements.
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
// custom fp16: aliases for vector_type<custom_half_t, N>::type, N = 2 .. 64
using custom_half2_t  = typename vector_type<custom_half_t, 2>::type;
using custom_half4_t  = typename vector_type<custom_half_t, 4>::type;
using custom_half8_t  = typename vector_type<custom_half_t, 8>::type;
using custom_half16_t = typename vector_type<custom_half_t, 16>::type;
using custom_half32_t = typename vector_type<custom_half_t, 32>::type;
using custom_half64_t = typename vector_type<custom_half_t, 64>::type;
// bfp16: aliases for vector_type<bhalf_t, N>::type
using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment