Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
63b152d6
Commit
63b152d6
authored
Oct 17, 2024
by
danyao12
Browse files
Merge branch 'develop' into ck_tile/fa_bwd_v3
parents
ae2d7d2b
14c3cfb1
Changes
132
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
736 additions
and
108 deletions
+736
-108
example/ck_tile/CMakeLists.txt
example/ck_tile/CMakeLists.txt
+1
-0
include/ck/config.h.in
include/ck/config.h.in
+0
-7
include/ck/host_utility/kernel_launch.hpp
include/ck/host_utility/kernel_launch.hpp
+6
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
...or_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
+5
-4
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
+10
-8
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
+5
-4
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+17
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
...evice/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
...grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
...device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+1
-1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+6
-0
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
+9
-9
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+6
-6
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+29
-29
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+625
-30
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+4
-0
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+4
-1
No files found.
example/ck_tile/CMakeLists.txt
View file @
63b152d6
...
...
@@ -5,3 +5,4 @@ include_directories(AFTER
add_subdirectory
(
01_fmha
)
add_subdirectory
(
02_layernorm2d
)
add_subdirectory
(
03_gemm
)
add_subdirectory
(
04_img2col
)
include/ck/config.h.in
View file @
63b152d6
...
...
@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
//
// CK kernels which support XDL (MI series)
//
...
...
include/ck/host_utility/kernel_launch.hpp
View file @
63b152d6
...
...
@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventDestroy
(
start
));
hip_check_error
(
hipEventDestroy
(
stop
));
return
total_time
/
nrepeat
;
}
else
...
...
@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventDestroy
(
start
));
hip_check_error
(
hipEventDestroy
(
stop
));
return
total_time
/
nrepeat
;
}
else
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
View file @
63b152d6
...
...
@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}
template
<
>
__device__
static
constexpr
auto
TailScheduler
<
1
>
()
__device__
constexpr
auto
TailScheduler
<
1
>
()
{
// schedule
constexpr
auto
num_ds_read_inst
=
...
...
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}
template
<
>
__device__
static
constexpr
auto
TailScheduler
<
2
>
()
__device__
constexpr
auto
TailScheduler
<
2
>
()
{
// schedule
constexpr
auto
num_ds_read_inst
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
View file @
63b152d6
...
...
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
...
@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
View file @
63b152d6
...
...
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
...
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
...
@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
...
...
@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
View file @
63b152d6
...
...
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
...
@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
...
...
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
View file @
63b152d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "device_base.hpp"
...
...
@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator
CElementwiseOperation
c_element_op
,
ck
::
index_t
KBatch
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
size_t
GetWorkspaceSize
(
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
)
=
0
;
index_t
StrideC
)
const
=
0
;
};
template
<
typename
AElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
63b152d6
...
...
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[
maybe_unused
]]
index_t
K
,
[[
maybe_unused
]]
index_t
StrideA
,
[[
maybe_unused
]]
index_t
StrideB
,
index_t
StrideC
)
override
index_t
StrideC
)
const
override
{
return
2
*
sizeof
(
CDataType
)
*
GetCElementSpaceSize
(
M
,
N
,
StrideC
);
}
std
::
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
base_arg
)
const
override
{
const
auto
*
parg
=
dynamic_cast
<
const
Argument
*>
(
base_arg
);
if
(
!
parg
)
{
std
::
ostringstream
err
;
err
<<
"Provided argument pointer is not of an Argument class!"
<<
" In "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
return
GetWorkspaceSize
(
parg
->
M
,
parg
->
N
,
parg
->
K
,
parg
->
StrideA
,
parg
->
StrideB
,
parg
->
StrideC
);
}
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
View file @
63b152d6
...
...
@@ -64,7 +64,7 @@ __global__ void
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
return
;
const
auto
StrideAs
=
gemm_desc_ptr
[
group_id
].
StrideAs
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
View file @
63b152d6
...
...
@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const
index_t
N
=
gemm_descs
[
i
].
N_
;
const
index_t
K
=
gemm_descs
[
i
].
K_
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
{
skipped_group_count_
++
;
continue
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
View file @
63b152d6
...
...
@@ -109,7 +109,7 @@ __global__ void
N
=
gemm_desc_ptr
[
group_id
].
N
;
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
{
grid_size_grp
=
0
;
continue
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
63b152d6
...
...
@@ -68,7 +68,7 @@ __global__ void
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
return
;
const
auto
StrideA
=
gemm_desc_ptr
[
group_id
].
StrideA
;
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
63b152d6
...
...
@@ -419,6 +419,12 @@ struct UnaryAbs
y
=
ck
::
math
::
abs
(
x
);
};
template
<
>
__host__
__device__
void
operator
()(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
y
=
ck
::
type_convert
<
f8_t
>
(
ck
::
math
::
abs
(
ck
::
type_convert
<
float
>
(
x
)));
};
};
struct
UnarySqrt
...
...
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
View file @
63b152d6
...
...
@@ -324,55 +324,55 @@ struct DppSelector
static
constexpr
auto
GetDpp
();
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
8
,
32
>
()
{
return
DppInstr
::
dpp8_f16_8x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
{
return
DppInstr
::
dpp8_f16_8x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
{
return
DppInstr
::
dpp8_f16_16x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
32
,
8
>
()
constexpr
auto
GetDpp
<
half_t
,
32
,
8
>
()
{
return
DppInstr
::
dpp8_f16_32x8x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
{
return
DppInstr
::
dpp8_f16_1x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
{
return
DppInstr
::
dpp8_f16_2x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
{
return
DppInstr
::
dpp8_f16_2x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
{
return
DppInstr
::
dpp8_f16_4x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
{
return
DppInstr
::
dpp8_f16_4x32x2
;
}
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
63b152d6
...
...
@@ -415,7 +415,7 @@ struct WmmaSelector
static
constexpr
auto
GetWmma
();
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
;
...
...
@@ -425,7 +425,7 @@ struct WmmaSelector
}
template
<
>
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_bf16_gfx12
;
...
...
@@ -435,19 +435,19 @@ struct WmmaSelector
}
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
half_t
,
16
,
16
>
()
constexpr
auto
GetWmma
<
half_t
,
half_t
,
half_t
,
16
,
16
>
()
{
return
WmmaInstr
::
wmma_f16_16x16x16_f16
;
}
template
<
>
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
bhalf_t
,
16
,
16
>
()
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
bhalf_t
,
16
,
16
>
()
{
return
WmmaInstr
::
wmma_bf16_16x16x16_bf16
;
}
template
<
>
static
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_i32_16x16x16_iu8_gfx12
;
...
...
@@ -458,7 +458,7 @@ struct WmmaSelector
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
static
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
{
return
WmmaInstr
::
wmma_i32_16x16x16_iu4
;
}
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
63b152d6
...
...
@@ -651,97 +651,97 @@ struct MfmaSelector
static
constexpr
auto
GetMfma
();
template
<
>
static
constexpr
auto
GetMfma
<
double
,
16
,
16
>
()
constexpr
auto
GetMfma
<
double
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f64_16x16x4f64
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
64
,
64
>
()
constexpr
auto
GetMfma
<
float
,
64
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
32
,
64
>
()
constexpr
auto
GetMfma
<
float
,
32
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
16
,
64
>
()
constexpr
auto
GetMfma
<
float
,
16
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
8
,
64
>
()
constexpr
auto
GetMfma
<
float
,
8
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
4
,
64
>
()
constexpr
auto
GetMfma
<
float
,
4
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
32
,
32
>
()
constexpr
auto
GetMfma
<
float
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x2xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
16
,
16
>
()
constexpr
auto
GetMfma
<
float
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x4xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
64
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
64
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
32
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
32
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
half_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x8f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
half_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x16f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
16
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
16
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
8
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
8
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
4
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
4
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bhalf_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
bhalf_t
,
32
,
32
>
()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return
MfmaInstr
::
mfma_f32_32x32x8bf16_1k
;
...
...
@@ -751,7 +751,7 @@ struct MfmaSelector
}
template
<
>
static
constexpr
auto
GetMfma
<
bhalf_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
bhalf_t
,
16
,
16
>
()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return
MfmaInstr
::
mfma_f32_16x16x16bf16_1k
;
...
...
@@ -762,72 +762,72 @@ struct MfmaSelector
#if defined(CK_USE_AMD_MFMA_GFX940)
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_i32_32x32x16i8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_i32_16x16x32i8
;
}
#else
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_i32_32x32x8i8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_i32_16x16x16i8
;
}
#endif
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16f8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16bf8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16f8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
,
bf8_t
>
()
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
,
bf8_t
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32f8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16bf8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
,
f8_t
>
()
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
,
f8_t
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8f8
;
}
...
...
include/ck/utility/data_type.hpp
View file @
63b152d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -13,8 +13,24 @@ using int4_t = _BitInt(4);
using
f8_t
=
_BitInt
(
8
);
using
bf8_t
=
unsigned
_BitInt
(
8
);
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
{
// Precondition: x > 1.
return
x
>
1u
?
(
1u
<<
(
32u
-
__builtin_clz
(
x
-
1u
)))
:
x
;
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
template
<
typename
T
>
inline
constexpr
bool
is_native_type
()
{
return
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
bhalf_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
uint8_t
>::
value
||
is_same
<
T
,
f8_t
>::
value
||
is_same
<
T
,
bf8_t
>::
value
||
is_same
<
T
,
bool
>::
value
;
}
// vector_type
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
,
typename
Enable
=
void
>
struct
vector_type
;
// Caution: DO NOT REMOVE
...
...
@@ -171,7 +187,7 @@ struct scalar_type<bool>
};
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
using
type
=
d1_t
;
...
...
@@ -189,7 +205,8 @@ struct vector_type<T, 1>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
,
"wrong!"
);
static_assert
(
is_same
<
X
,
d1_t
>::
value
,
"Something went wrong, please check src and dst types."
);
return
data_
.
d1x1_
;
}
...
...
@@ -197,7 +214,8 @@ struct vector_type<T, 1>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
,
"wrong!"
);
static_assert
(
is_same
<
X
,
d1_t
>::
value
,
"Something went wrong, please check src and dst types."
);
return
data_
.
d1x1_
;
}
...
...
@@ -205,7 +223,7 @@ struct vector_type<T, 1>
__device__
int
static
err
=
0
;
template
<
typename
T
>
struct
vector_type
<
T
,
2
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -226,7 +244,8 @@ struct vector_type<T, 2>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"wrong!"
);
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -245,7 +264,8 @@ struct vector_type<T, 2>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"wrong!"
);
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -263,7 +283,7 @@ struct vector_type<T, 2>
};
template
<
typename
T
>
struct
vector_type
<
T
,
4
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -287,7 +307,7 @@ struct vector_type<T, 4>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -311,7 +331,7 @@ struct vector_type<T, 4>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -333,7 +353,7 @@ struct vector_type<T, 4>
};
template
<
typename
T
>
struct
vector_type
<
T
,
8
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -360,7 +380,7 @@ struct vector_type<T, 8>
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -389,7 +409,7 @@ struct vector_type<T, 8>
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -415,7 +435,7 @@ struct vector_type<T, 8>
};
template
<
typename
T
>
struct
vector_type
<
T
,
16
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -445,7 +465,7 @@ struct vector_type<T, 16>
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -479,7 +499,7 @@ struct vector_type<T, 16>
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -509,7 +529,7 @@ struct vector_type<T, 16>
};
template
<
typename
T
>
struct
vector_type
<
T
,
32
>
struct
vector_type
<
T
,
32
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -541,7 +561,7 @@ struct vector_type<T, 32>
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -579,7 +599,7 @@ struct vector_type<T, 32>
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -613,7 +633,7 @@ struct vector_type<T, 32>
};
template
<
typename
T
>
struct
vector_type
<
T
,
64
>
struct
vector_type
<
T
,
64
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -648,7 +668,7 @@ struct vector_type<T, 64>
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -691,7 +711,7 @@ struct vector_type<T, 64>
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -729,7 +749,7 @@ struct vector_type<T, 64>
};
template
<
typename
T
>
struct
vector_type
<
T
,
128
>
struct
vector_type
<
T
,
128
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -766,7 +786,7 @@ struct vector_type<T, 128>
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -813,7 +833,7 @@ struct vector_type<T, 128>
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -855,7 +875,7 @@ struct vector_type<T, 128>
};
template
<
typename
T
>
struct
vector_type
<
T
,
256
>
struct
vector_type
<
T
,
256
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -894,7 +914,7 @@ struct vector_type<T, 256>
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
||
is_same
<
X
,
d256_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -945,7 +965,7 @@ struct vector_type<T, 256>
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
||
is_same
<
X
,
d256_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -990,6 +1010,581 @@ struct vector_type<T, 256>
}
};
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
{
using
type
=
non_native_vector_base
<
T
,
N
>
;
__host__
__device__
non_native_vector_base
()
=
default
;
__host__
__device__
non_native_vector_base
(
const
type
&
)
=
default
;
__host__
__device__
non_native_vector_base
(
type
&&
)
=
default
;
__host__
__device__
~
non_native_vector_base
()
=
default
;
T
d
[
N
];
};
// non-native vector_type implementation
template
<
typename
T
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
type
=
d1_t
;
union
alignas
(
next_pow2
(
1
*
sizeof
(
T
)))
{
d1_t
d1_
;
StaticallyIndexedArray
<
d1_t
,
1
>
d1x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
,
"Something went wrong, please check src and dst types."
);
return
data_
.
d1x1_
;
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
,
"Something went wrong, please check src and dst types."
);
return
data_
.
d1x1_
;
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
type
=
d2_t
;
union
alignas
(
next_pow2
(
2
*
sizeof
(
T
)))
{
d2_t
d2_
;
StaticallyIndexedArray
<
d1_t
,
2
>
d1x2_
;
StaticallyIndexedArray
<
d2_t
,
1
>
d2x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
type
=
d4_t
;
union
alignas
(
next_pow2
(
4
*
sizeof
(
T
)))
{
d4_t
d4_
;
StaticallyIndexedArray
<
d1_t
,
4
>
d1x4_
;
StaticallyIndexedArray
<
d2_t
,
2
>
d2x2_
;
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
using
type
=
d8_t
;
union
alignas
(
next_pow2
(
8
*
sizeof
(
T
)))
{
d8_t
d8_
;
StaticallyIndexedArray
<
d1_t
,
8
>
d1x8_
;
StaticallyIndexedArray
<
d2_t
,
4
>
d2x4_
;
StaticallyIndexedArray
<
d4_t
,
2
>
d4x2_
;
StaticallyIndexedArray
<
d8_t
,
1
>
d8x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
using
d16_t
=
non_native_vector_base
<
T
,
16
>
;
using
type
=
d16_t
;
union
alignas
(
next_pow2
(
16
*
sizeof
(
T
)))
{
d16_t
d16_
;
StaticallyIndexedArray
<
d1_t
,
16
>
d1x16_
;
StaticallyIndexedArray
<
d2_t
,
8
>
d2x8_
;
StaticallyIndexedArray
<
d4_t
,
4
>
d4x4_
;
StaticallyIndexedArray
<
d8_t
,
2
>
d8x2_
;
StaticallyIndexedArray
<
d16_t
,
1
>
d16x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
32
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
using
d16_t
=
non_native_vector_base
<
T
,
16
>
;
using
d32_t
=
non_native_vector_base
<
T
,
32
>
;
using
type
=
d32_t
;
union
alignas
(
next_pow2
(
32
*
sizeof
(
T
)))
{
d32_t
d32_
;
StaticallyIndexedArray
<
d1_t
,
32
>
d1x32_
;
StaticallyIndexedArray
<
d2_t
,
16
>
d2x16_
;
StaticallyIndexedArray
<
d4_t
,
8
>
d4x8_
;
StaticallyIndexedArray
<
d8_t
,
4
>
d8x4_
;
StaticallyIndexedArray
<
d16_t
,
2
>
d16x2_
;
StaticallyIndexedArray
<
d32_t
,
1
>
d32x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
64
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
using
d16_t
=
non_native_vector_base
<
T
,
16
>
;
using
d32_t
=
non_native_vector_base
<
T
,
32
>
;
using
d64_t
=
non_native_vector_base
<
T
,
64
>
;
using
type
=
d64_t
;
union
alignas
(
next_pow2
(
64
*
sizeof
(
T
)))
{
d64_t
d64_
;
StaticallyIndexedArray
<
d1_t
,
64
>
d1x64_
;
StaticallyIndexedArray
<
d2_t
,
32
>
d2x32_
;
StaticallyIndexedArray
<
d4_t
,
16
>
d4x16_
;
StaticallyIndexedArray
<
d8_t
,
8
>
d8x8_
;
StaticallyIndexedArray
<
d16_t
,
4
>
d16x4_
;
StaticallyIndexedArray
<
d32_t
,
2
>
d32x2_
;
StaticallyIndexedArray
<
d64_t
,
1
>
d64x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x1_
;
}
else
{
return
err
;
}
}
};
using
int64_t
=
long
;
// fp64
...
...
@@ -1051,8 +1646,8 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
// u8
// i8
using
uint8x2_t
=
typename
vector_type
<
uint8_t
,
2
>::
type
;
using
uint8x4_t
=
typename
vector_type
<
uint8_t
,
4
>::
type
;
using
uint8x8_t
=
typename
vector_type
<
uint8_t
,
8
>::
type
;
...
...
include/ck/utility/math_v2.hpp
View file @
63b152d6
...
...
@@ -80,6 +80,8 @@ static inline __host__ bool isnan(half_t x)
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
static
inline
__host__
bool
isnan
(
f8_t
x
)
{
return
(
x
&
0x80
);
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static
inline
__host__
bool
isnan
(
int4_t
x
)
{
...
...
@@ -529,6 +531,8 @@ static inline __device__ bool isnan(half_t x)
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
static
inline
__device__
bool
isnan
(
f8_t
x
)
{
return
(
x
&
0x80
);
};
static
inline
__device__
half_t
sqrt
(
half_t
x
)
{
return
static_cast
<
half_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
...
...
include/ck_tile/core/config.hpp
View file @
63b152d6
...
...
@@ -157,8 +157,11 @@
#endif
#endif
// workaround for ROCm 6.2 and later
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133) || \
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 3 && HIP_VERSION_PATCH >= 42131) || \
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR > 3)
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment