gaoqiong / composable_kernel_ROCM · Commits

Commit dbb7002d
Authored Feb 06, 2025 by Adam Osewski

Merge remote-tracking branch 'origin/develop' into aosewski/hotloop

Parents: 96c8d948, 2bef5501

Showing 20 changed files (of 228 changed in this commit) with 1754 additions and 184 deletions (+1754, -184)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp  +1 -0
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  +1 -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp  +40 -13
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp  +13 -1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp  +7 -2
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  +2 -2
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp  +312 -11
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp  +74 -63
include/ck/utility/amd_buffer_addressing.hpp  +16 -2
include/ck/utility/amd_ck_fp8.hpp  +29 -33
include/ck/utility/amd_wave_read_first_lane.hpp  +14 -13
include/ck/utility/amd_xdlops.hpp  +264 -1
include/ck/utility/array.hpp  +4 -2
include/ck/utility/container_helper.hpp  +3 -3
include/ck/utility/data_type.hpp  +868 -32
include/ck/utility/debug.hpp  +2 -1
include/ck/utility/e8m0.hpp  +80 -0
include/ck/utility/enable_if.hpp  +18 -1
include/ck/utility/env.hpp  +3 -1
include/ck/utility/functional.hpp  +3 -3
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp
@@ -773,6 +773,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
     // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
     // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
     // therefore we may just as well assign Gemm1KPack = group_size
     constexpr index_t Gemm1KPack =
         MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
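For readers unpacking that comment: group_size here is 4 and the selected f16 instruction interleaves two input blocks, so a lane's K indices advance in strides of group_size * num_input_blks. An 8-wide pack on the A1 side would therefore span K = {0..3, 8..11} while the 8-wide B1 pack covers the contiguous K = {0..7}, which is exactly the mismatch c[0:7] = a1[[0:3, 8:11]] * b1[0:7] that the comment warns about. A toy restatement of that index arithmetic (illustrative constants, not repository code):

// K index of slot j within group g for one lane, with two interleaved input blocks.
constexpr int group_size     = 4;
constexpr int num_input_blks = 2;

constexpr int k_index(int group, int slot) { return group * group_size * num_input_blks + slot; }

// Two groups' worth of A1 data covers K = {0,1,2,3, 8,9,10,11}, not {0..7}:
static_assert(k_index(0, 3) == 3, "group 0 ends at K=3");
static_assert(k_index(1, 0) == 8, "group 1 resumes at K=8");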
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
@@ -628,6 +628,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
     // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
     // therefore we may just as well assign Gemm1KPack = group_size
     constexpr index_t Gemm1KPack =
         MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -423,10 +423,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     }

     template <typename AsLayout, GemmSpecialization GemmSpec>
-    __host__ __device__ static auto
-    MakeAsGridDescriptor_M_K(const std::array<index_t, NumATensor>& MRaws,
-                             const std::array<index_t, NumATensor>& KRaws,
-                             const std::array<index_t, NumATensor>& AsStride)
+    __host__ __device__ static auto MakeAsGridDescriptor_M_K(
+#ifdef CK_CODE_GEN_RTC
+        const ck::Array<index_t, NumATensor>& MRaws,
+        const ck::Array<index_t, NumATensor>& KRaws,
+        const ck::Array<index_t, NumATensor>& AsStride
+#else
+        const std::array<index_t, NumATensor>& MRaws,
+        const std::array<index_t, NumATensor>& KRaws,
+        const std::array<index_t, NumATensor>& AsStride
+#endif
+        )
     {
         return generate_tuple(
             [&](auto i) {

@@ -462,10 +469,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     }

     template <typename BsLayout, GemmSpecialization GemmSpec>
-    __host__ __device__ static auto
-    MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& NRaws,
-                             const std::array<index_t, NumBTensor>& KRaws,
-                             const std::array<index_t, NumBTensor>& BsStride)
+    __host__ __device__ static auto MakeBsGridDescriptor_N_K(
+#ifdef CK_CODE_GEN_RTC
+        const ck::Array<index_t, NumBTensor>& NRaws,
+        const ck::Array<index_t, NumBTensor>& KRaws,
+        const ck::Array<index_t, NumBTensor>& BsStride
+#else
+        const std::array<index_t, NumBTensor>& NRaws,
+        const std::array<index_t, NumBTensor>& KRaws,
+        const std::array<index_t, NumBTensor>& BsStride
+#endif
+        )
     {
         return generate_tuple(
             [&](auto i) {

@@ -500,10 +514,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     }

     template <typename DsLayout, GemmSpecialization GemmSpec>
-    __host__ __device__ static auto
-    MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
-                             const std::array<index_t, NumDTensor>& NRaws,
-                             const std::array<index_t, NumDTensor>& DsStride)
+    __host__ __device__ static auto MakeDsGridDescriptor_M_N(
+#ifdef CK_CODE_GEN_RTC
+        const ck::Array<index_t, NumDTensor>& MRaws,
+        const ck::Array<index_t, NumDTensor>& NRaws,
+        const ck::Array<index_t, NumDTensor>& DsStride
+#else
+        const std::array<index_t, NumDTensor>& MRaws,
+        const std::array<index_t, NumDTensor>& NRaws,
+        const std::array<index_t, NumDTensor>& DsStride
+#endif
+        )
     {
         return generate_tuple(
             [&](auto i) {

@@ -969,9 +990,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
                       const index_t M,
                       const index_t N,
                       const index_t K,
+#ifdef CK_CODE_GEN_RTC
+                      const ck::Array<index_t, NumATensor> StrideAs,
+                      const ck::Array<index_t, NumBTensor> StrideBs,
+                      const ck::Array<index_t, NumDTensor> StrideDs,
+#else
                       const std::array<index_t, NumATensor> StrideAs,
                       const std::array<index_t, NumBTensor> StrideBs,
                       const std::array<index_t, NumDTensor> StrideDs,
+#endif
                       const index_t StrideE,
                       const Block2ETileMap& block_2_etile_map)
     {
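A pattern worth noting across these hunks: every std::array parameter gains a ck::Array twin behind CK_CODE_GEN_RTC, because hipRTC compiles these headers without the host standard library. A minimal sketch of the same switch (reduced Array stand-in; the real one is ck::Array from ck/utility/array.hpp):

#ifdef CK_CODE_GEN_RTC
// Reduced stand-in for ck::Array: hipRTC has no <array>.
template <typename T, int N>
struct Array
{
    T mData[N];
    constexpr const T& operator[](int i) const { return mData[i]; }
};
template <typename T, int N>
using stride_array = Array<T, N>;
#else
#include <array>
template <typename T, int N>
using stride_array = std::array<T, N>;
#endif

// One alias keeps a single function body valid on both compilation paths:
template <int NumTensor>
constexpr int sum_strides(const stride_array<int, NumTensor>& strides)
{
    int acc = 0;
    for(int i = 0; i < NumTensor; ++i)
        acc += strides[i];
    return acc;
}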
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -473,11 +473,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
     }

+#ifdef CK_CODE_GEN_RTC
+    template <typename DsLayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeDsGridDescriptor_M_N(const ck::Array<index_t, NumDTensor>& MRaws,
+                             const ck::Array<index_t, NumDTensor>& NRaws,
+                             const ck::Array<index_t, NumDTensor>& DsStride)
+#else
     template <typename DsLayout, GemmSpecialization GemmSpec>
     __host__ __device__ static auto
     MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
                              const std::array<index_t, NumDTensor>& NRaws,
                              const std::array<index_t, NumDTensor>& DsStride)
+#endif
     {
         return generate_tuple(
             [&](auto i) {

@@ -941,7 +949,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                       const index_t K,
                       const index_t StrideA,
                       const index_t StrideB,
+#ifdef CK_CODE_GEN_RTC
+                      const ck::Array<index_t, NumDTensor> StrideDs,
+#else
                       const std::array<index_t, NumDTensor> StrideDs,
+#endif
                       const index_t StrideE,
                       const Block2ETileMap& block_2_etile_map)
     {
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef CK_CODE_GEN_RTC
 #include <iostream>
 #include <ostream>
+#endif

 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"

@@ -53,12 +54,15 @@ constexpr auto GridwiseGemmPipeline_Selector()
     }
     else
     {
+#ifndef CK_CODE_GEN_RTC
         std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
+#endif
     }
 }

 } // namespace ck

+#ifndef CK_CODE_GEN_RTC
 inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
 {
     switch(p)

@@ -71,3 +75,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
     }
     return os;
 }
+#endif
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -212,7 +212,7 @@ template <typename SrcData,
           typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
 struct ThreadwiseTensorSliceTransfer_v2
 {
-    static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
+    static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
                       (!InvalidElementAsNaN),
                   "Filling invalid element as NaN is only for floating point types");
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
@@ -37,7 +37,17 @@ enum struct MfmaInstr
     mfma_f32_32x32x16f8bf8,
     mfma_f32_16x16x32f8bf8,
     mfma_f32_32x32x16bf8f8,
-    mfma_f32_16x16x32bf8f8
+    mfma_f32_16x16x32bf8f8,
+    mfma_f32_32x32x16f16,
+    mfma_f32_16x16x32f16,
+    mfma_f32_32x32x16bf16,
+    mfma_f32_16x16x32bf16,
+    mfma_i32_32x32x32i8,
+    mfma_i32_16x16x64i8,
+    mfma_f32_32x32x64f8f6f4,
+    mfma_f32_16x16x128f8f6f4,
+    mfma_scale_f32_32x32x64f8f6f4,
+    mfma_scale_f32_16x16x128f8f6f4
 };

 template <MfmaInstr instr>
@@ -198,6 +208,50 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8f16>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x16f16>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x32f16>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_16x16x16f16>
 {
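The new trait values follow the arithmetic invariants that the f8f6f4 entries later in this file spell out in comments: group_size * num_groups_per_blk == num_regs_per_blk, num_regs_per_blk == m_per_blk * n_per_blk / wave_size, num_threads_per_blk == n_per_blk, and under is_k_reduction the instruction K equals k_per_blk * num_input_blks. A standalone sanity check against the mfma_f32_32x32x16f16 entry above (my restatement, not repository code):

// Values copied from mfma_type<MfmaInstr::mfma_f32_32x32x16f16>.
constexpr int group_size          = 4;
constexpr int num_groups_per_blk  = 4;
constexpr int num_regs_per_blk    = 16;
constexpr int num_threads_per_blk = 32;
constexpr int wave_size           = 64;
constexpr int num_input_blks      = 2;
constexpr int m_per_blk           = 32;
constexpr int n_per_blk           = 32;
constexpr int k_per_blk           = 8;

static_assert(group_size * num_groups_per_blk == num_regs_per_blk, "regs are split into groups");
static_assert(num_regs_per_blk == m_per_blk * n_per_blk / wave_size, "one C element per lane per reg");
static_assert(num_threads_per_blk == n_per_blk, "one lane per output column");
static_assert(k_per_blk * num_input_blks == 16, "K split across input blocks gives 32x32x16");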
@@ -264,6 +318,28 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf16>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
 {
@@ -286,6 +362,28 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf16>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_16x16x16bf16_1k>
 {
@@ -440,6 +538,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x32i8>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_i32_32x32x32i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 16;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_32x32x32i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_i32_16x16x64i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 16;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_16x16x64i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
 {
@@ -638,16 +780,115 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
     }
 };

+// TODO: fix mfma...f8f6f4 instructions
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
+{
+    // clang-format off
+    static constexpr index_t group_size          = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 16;   // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 32;   // n_per_blk
+    static constexpr index_t wave_size           = 64;   // fixed
+    static constexpr index_t num_input_blks      = 2;    // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;    // (is_k_reduction == true) ???
+    static constexpr index_t m_per_blk           = 32;   // from the instruction
+    static constexpr index_t n_per_blk           = 32;   // from the instruction
+    static constexpr index_t k_per_blk           = 32;   // (is_k_reduction == true) ? 64 / num_input_blks
+    static constexpr bool is_k_reduction         = true; // ???
+    // clang-format on
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
+{
+    // clang-format off
+    static constexpr index_t group_size          = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 1;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 4;    // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 16;   // == n_per_blk
+    static constexpr index_t wave_size           = 64;   // fixed
+    static constexpr index_t num_input_blks      = 4;    // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;    // (is_k_reduction == true) ???
+    static constexpr index_t m_per_blk           = 16;   // from the instruction
+    static constexpr index_t n_per_blk           = 16;   // from the instruction
+    static constexpr index_t k_per_blk           = 32;   // (is_k_reduction == true) ? 128 / num_input_blks
+    static constexpr bool is_k_reduction         = true; // ???
+    // clang-format on
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
+{
+    // clang-format off
+    static constexpr index_t group_size          = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 16;   // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 32;   // n_per_blk
+    static constexpr index_t wave_size           = 64;   // fixed
+    static constexpr index_t num_input_blks      = 2;    // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;    // (is_k_reduction == true) ???
+    static constexpr index_t m_per_blk           = 32;   // from the instruction
+    static constexpr index_t n_per_blk           = 32;   // from the instruction
+    static constexpr index_t k_per_blk           = 32;   // (is_k_reduction == true) ? 64 / num_input_blks
+    static constexpr bool is_k_reduction         = true; // ???
+    // clang-format on
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
+{
+    // clang-format off
+    static constexpr index_t group_size          = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 1;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 4;    // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 16;   // == n_per_blk
+    static constexpr index_t wave_size           = 64;   // fixed
+    static constexpr index_t num_input_blks      = 4;    // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;    // (is_k_reduction == true) ???
+    static constexpr index_t m_per_blk           = 16;   // from the instruction
+    static constexpr index_t n_per_blk           = 16;   // from the instruction
+    static constexpr index_t k_per_blk           = 32;   // (is_k_reduction == true) ? 128 / num_input_blks
+    static constexpr bool is_k_reduction         = true; // ???
+    // clang-format on
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <typename base_type,
           index_t MPerXdlops,
           index_t NPerXdlops,
-          typename additional_type = base_type>
+          typename additional_type = base_type,
+          bool is_single_rate_mfma = false>
 struct MfmaSelector
 {
     template <typename base_type_,
               index_t MPerXdlops_,
               index_t NPerXdlops_,
-              typename additional_type_ = base_type_>
+              typename additional_type_ = base_type_,
+              bool is_single_rate_mfma_ = false>
     static constexpr auto GetMfma();

     template <>
@@ -711,13 +952,32 @@ struct MfmaSelector
     }

     template <>
-    constexpr auto GetMfma<half_t, 32, 32>()
+    constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
     {
+#if defined(__gfx950__)
+        return MfmaInstr::mfma_f32_32x32x16f16;
+#else
         return MfmaInstr::mfma_f32_32x32x8f16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
+    {
+        return MfmaInstr::mfma_f32_32x32x8f16;
     }

     template <>
-    constexpr auto GetMfma<half_t, 16, 16>()
+    constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
     {
+#if defined(__gfx950__)
+        return MfmaInstr::mfma_f32_16x16x32f16;
+#else
         return MfmaInstr::mfma_f32_16x16x16f16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
+    {
+        return MfmaInstr::mfma_f32_16x16x16f16;
     }
@@ -741,7 +1001,19 @@ struct MfmaSelector
     }

     template <>
-    constexpr auto GetMfma<bhalf_t, 32, 32>()
+    constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
     {
+#if defined(__gfx950__)
+        return MfmaInstr::mfma_f32_32x32x16bf16;
+#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
+        return MfmaInstr::mfma_f32_32x32x8bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_32x32x4bf16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
+    {
 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
         return MfmaInstr::mfma_f32_32x32x8bf16_1k;
@@ -751,7 +1023,19 @@ struct MfmaSelector
     }

     template <>
-    constexpr auto GetMfma<bhalf_t, 16, 16>()
+    constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
     {
+#if defined(__gfx950__)
+        return MfmaInstr::mfma_f32_16x16x32bf16;
+#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
+        return MfmaInstr::mfma_f32_16x16x16bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_16x16x8bf16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
+    {
 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
         return MfmaInstr::mfma_f32_16x16x16bf16_1k;
@@ -760,7 +1044,18 @@ struct MfmaSelector
 #endif
     }

-#if defined(CK_USE_AMD_MFMA_GFX940)
+#if defined(__gfx950__)
+    template <>
+    constexpr auto GetMfma<int8_t, 32, 32>()
+    {
+        return MfmaInstr::mfma_i32_32x32x32i8;
+    }
+    template <>
+    constexpr auto GetMfma<int8_t, 16, 16>()
+    {
+        return MfmaInstr::mfma_i32_16x16x64i8;
+    }
+#elif defined(__gfx942__)
     template <>
     constexpr auto GetMfma<int8_t, 32, 32>()
     {
@@ -832,8 +1127,8 @@ struct MfmaSelector
         return MfmaInstr::mfma_f32_16x16x32bf8f8;
     }

-    static constexpr auto selected_mfma =
-        mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
+    static constexpr auto selected_mfma = mfma_type<
+        GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};

     __host__ __device__ constexpr MfmaSelector()
     {
@@ -1135,7 +1430,13 @@ struct XdlopsGemm
         return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
     }

-    static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
+    // Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942-
+    static constexpr auto mfma =
+        MfmaSelector<base_type,
+                     MPerXdlops,
+                     NPerXdlops,
+                     additional_type,
+                     ((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) &&
+                      KPack <= 4)
+                         ? true
+                         : false>{};

     static constexpr auto mfma_instr = mfma.selected_mfma;
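The new trailing template parameter makes the dual-rate fallback explicit: for fp16/bf16 operands with KPack <= 4, the selector is instantiated with is_single_rate_mfma = true, which resolves to the older single-rate instruction even on gfx950. A reduced, standalone sketch of that compile-time dispatch (hypothetical names; half_t replaced by a tag type so it compiles anywhere):

#include <type_traits>

struct half_t_tag {}; // stands in for ck::half_t

enum class Instr { mfma_f32_32x32x8f16, mfma_f32_32x32x16f16 };

template <typename BaseType, int KPack>
constexpr Instr select_mfma()
{
    // Double-rate gfx950 f16 MFMA needs 8 K elements per lane; with KPack <= 4
    // only the single-rate 32x32x8 form can be fed.
    constexpr bool single_rate = std::is_same<BaseType, half_t_tag>::value && KPack <= 4;
    return single_rate ? Instr::mfma_f32_32x32x8f16 : Instr::mfma_f32_32x32x16f16;
}

static_assert(select_mfma<half_t_tag, 4>() == Instr::mfma_f32_32x32x8f16, "single-rate fallback");
static_assert(select_mfma<half_t_tag, 8>() == Instr::mfma_f32_32x32x16f16, "double rate on gfx950");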
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck/library/utility/numeric.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -148,8 +147,8 @@ struct TransformConvFwdToGemm
     template <typename ConvDimsType,
               typename ConvSpatialDimsType,
-              index_t NDim = NDimSpatial,
-              typename std::enable_if<NDim == 1, bool>::type = false>
+              index_t NDim = NDimSpatial,
+              typename ck::enable_if<NDim == 1, bool>::type = false>
     __host__ __device__
     TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
                            const ConvDimsType& a_g_n_c_wis_strides,
                            const ConvDimsType& b_g_k_c_xs_lengths,
@@ -201,11 +200,15 @@ struct TransformConvFwdToGemm
           InRightPadW_{input_right_pads[I0]},
           ZYX_{X_}
     {
+#ifdef CK_CODE_GEN_RTC
+        static_assert(is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+#else
         static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
                       is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
         static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
                       is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+#endif
         if constexpr(SplitN)
         {
             N_ = GetSplitedNSize(
@@ -219,8 +222,8 @@ struct TransformConvFwdToGemm
     template <typename ConvDimsType,
               typename ConvSpatialDimsType,
-              index_t NDim = NDimSpatial,
-              typename std::enable_if<NDim == 2, bool>::type = false>
+              index_t NDim = NDimSpatial,
+              typename ck::enable_if<NDim == 2, bool>::type = false>
     __host__ __device__
     TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
                            const ConvDimsType& a_g_n_c_wis_strides,
                            const ConvDimsType& b_g_k_c_xs_lengths,
@@ -272,11 +275,15 @@ struct TransformConvFwdToGemm
           InRightPadW_{input_right_pads[I1]},
           ZYX_{Y_ * X_}
     {
+#ifdef CK_CODE_GEN_RTC
+        static_assert(is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+#else
         static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
                       is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
         static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
                       is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+#endif
         if constexpr(SplitN)
         {
             N_ = GetSplitedNSize(
@@ -290,8 +297,8 @@ struct TransformConvFwdToGemm
     template <typename ConvDimsType,
               typename ConvSpatialDimsType,
-              index_t NDim = NDimSpatial,
-              typename std::enable_if<NDim == 3, bool>::type = false>
+              index_t NDim = NDimSpatial,
+              typename ck::enable_if<NDim == 3, bool>::type = false>
     __host__ __device__
     TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
                            const ConvDimsType& a_g_n_c_wis_strides,
                            const ConvDimsType& b_g_k_c_xs_lengths,
@@ -343,11 +350,15 @@ struct TransformConvFwdToGemm
           InRightPadW_{input_right_pads[I2]},
           ZYX_{Z_ * Y_ * X_}
     {
+#ifdef CK_CODE_GEN_RTC
+        static_assert(is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+#else
         static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
                       is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
         static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
                       is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+#endif
         if constexpr(SplitN)
         {
             N_ = GetSplitedNSize(
@@ -478,11 +489,11 @@ struct TransformConvFwdToGemm
     // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
     // properties
     template <typename ALayout,
-              typename std::enable_if<NDimSpatial == 1 &&
-                                          (is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
-                                           is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
-                                           is_same_v<ALayout, tensor_layout::convolution::GNWC>),
-                                      bool>::type = false>
+              typename ck::enable_if<NDimSpatial == 1 &&
+                                         (is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
+                                          is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
+                                          is_same_v<ALayout, tensor_layout::convolution::GNWC>),
+                                     bool>::type = false>
     __host__ __device__ auto MakeADescriptor_M_K() const
     {
         if constexpr(ConvForwardSpecialization ==
@@ -691,11 +702,11 @@ struct TransformConvFwdToGemm
     }

     template <typename ALayout,
-              typename std::enable_if<NDimSpatial == 2 &&
-                                          (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
-                                           is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
-                                           is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
-                                      bool>::type = false>
+              typename ck::enable_if<NDimSpatial == 2 &&
+                                         (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
+                                          is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
+                                          is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
+                                     bool>::type = false>
     __host__ __device__ auto MakeADescriptor_M_K() const
     {
@@ -932,7 +943,7 @@ struct TransformConvFwdToGemm
     }

     template <typename ALayout,
-              typename std::enable_if<
+              typename ck::enable_if<
                   NDimSpatial == 3 &&
                       (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
                        is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
                        is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
@@ -1242,19 +1253,19 @@ struct TransformConvFwdToGemm
     }

     template <typename BLayout,
-              typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
-                                          is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
-                                          is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
-                                      bool>::type = false>
+              typename ck::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
+                                         is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
+                                         is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
+                                     bool>::type = false>
     __host__ __device__ auto MakeBDescriptor_N_K() const
     {
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter3x3)
         {
             using FilterSizeNumType =
-                std::conditional_t<NDimSpatial == 1,
-                                   Number<3>,
-                                   std::conditional_t<NDimSpatial == 2, Number<9>, Number<27>>>;
+                ck::conditional_t<NDimSpatial == 1,
+                                  Number<3>,
+                                  ck::conditional_t<NDimSpatial == 2, Number<9>, Number<27>>>;

             if constexpr(NumGroupsToMerge == 1)
             {
@@ -1297,13 +1308,13 @@ struct TransformConvFwdToGemm
     template <typename BLayout,
-              typename std::enable_if<
+              typename ck::enable_if<
                   is_same_v<BLayout, tensor_layout::convolution::G_K_X_C> ||
                       is_same_v<BLayout, tensor_layout::convolution::G_K_YX_C> ||
                       is_same_v<BLayout, tensor_layout::convolution::G_K_ZYX_C> ||
                       is_same_v<BLayout, tensor_layout::convolution::KXGC> ||
                       is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
                       is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
                   bool>::type = false>
     __host__ __device__ auto MakeBDescriptor_N_K() const
     {
         const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
@@ -1318,36 +1329,36 @@ struct TransformConvFwdToGemm
         return wei_gemmn_gemmk_desc;
     }

     template <typename CLayout,
               index_t NDimSp = NDimSpatial,
-              typename std::enable_if<NDimSp == 1 &&
-                                          (is_same_v<CLayout, tensor_layout::convolution::G_K>),
-                                      bool>::type = false>
+              typename ck::enable_if<NDimSp == 1 &&
+                                         (is_same_v<CLayout, tensor_layout::convolution::G_K>),
+                                     bool>::type = false>
     __host__ __device__ auto MakeCDescriptor_M_N() const
     {
         return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
                                             make_tuple(I0, KStrideTensorC_));
     }

     template <typename CLayout,
               index_t NDimSp = NDimSpatial,
-              typename std::enable_if<NDimSp == 2 &&
-                                          (is_same_v<CLayout, tensor_layout::convolution::G_K>),
-                                      bool>::type = false>
+              typename ck::enable_if<NDimSp == 2 &&
+                                         (is_same_v<CLayout, tensor_layout::convolution::G_K>),
+                                     bool>::type = false>
     __host__ __device__ auto MakeCDescriptor_M_N() const
     {
         return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
                                             make_tuple(I0, KStrideTensorC_));
     }

     template <typename CLayout,
               index_t NDimSp = NDimSpatial,
-              typename std::enable_if<NDimSp == 3 &&
-                                          (is_same_v<CLayout, tensor_layout::convolution::G_K>),
-                                      bool>::type = false>
+              typename ck::enable_if<NDimSp == 3 &&
+                                         (is_same_v<CLayout, tensor_layout::convolution::G_K>),
+                                     bool>::type = false>
     __host__ __device__ auto MakeCDescriptor_M_N() const
     {
         return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
@@ -1355,12 +1366,12 @@ struct TransformConvFwdToGemm
     }

     template <typename CLayout,
-              index_t NDimSp = NDimSpatial,
-              typename std::enable_if<NDimSp == 1 &&
-                                          (is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
-                                           is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
-                                           is_same_v<CLayout, tensor_layout::convolution::GNWK>),
-                                      bool>::type = false>
+              index_t NDimSp = NDimSpatial,
+              typename ck::enable_if<NDimSp == 1 &&
+                                         (is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
+                                          is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
+                                          is_same_v<CLayout, tensor_layout::convolution::GNWK>),
+                                     bool>::type = false>
     __host__ __device__ auto MakeCDescriptor_M_N() const
     {
         const IndexType NDoHoWo = N_ * Wo_;
@@ -1410,11 +1421,11 @@
     template <typename CLayout,
               index_t NDimSp = NDimSpatial,
-              typename std::enable_if<NDimSp == 2 &&
-                                          (is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
-                                           is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
-                                           is_same_v<CLayout, tensor_layout::convolution::GNHWK>),
-                                      bool>::type = false>
+              typename ck::enable_if<NDimSp == 2 &&
+                                         (is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
+                                          is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
+                                          is_same_v<CLayout, tensor_layout::convolution::GNHWK>),
+                                     bool>::type = false>
     __host__ __device__ auto MakeCDescriptor_M_N() const
     {
         const IndexType NDoHoWo = N_ * Ho_ * Wo_;
@@ -1467,7 +1478,7 @@
     template <typename CLayout,
               index_t NDimSp = NDimSpatial,
-              typename std::enable_if<
+              typename ck::enable_if<
                   NDimSp == 3 &&
                       (is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
                        is_same_v<CLayout, tensor_layout::convolution::NDHWGK> ||
                        is_same_v<CLayout, tensor_layout::convolution::GNDHWK>),
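All of these hunks make the same substitution: std::enable_if and std::conditional_t become ck::enable_if and ck::conditional_t, so the header no longer pulls in <type_traits> under hipRTC. Minimal freestanding definitions of that shape (a sketch; the real ones live in ck/utility/enable_if.hpp and the other ck utility headers):

namespace ck {

template <bool B, typename T = void>
struct enable_if
{
};

template <typename T>
struct enable_if<true, T>
{
    using type = T;
};

template <bool B, typename T = void>
using enable_if_t = typename enable_if<B, T>::type;

template <bool B, typename T, typename F>
struct conditional
{
    using type = T;
};

template <typename T, typename F>
struct conditional<false, T, F>
{
    using type = F;
};

template <bool B, typename T, typename F>
using conditional_t = typename conditional<B, T, F>::type;

} // namespace ck

// Usage in the same SFINAE shape as the overloads above:
template <int NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
constexpr int two_dim_only()
{
    return NDim;
}

static_assert(two_dim_only<2>() == 2, "selected only when NDim == 2");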
include/ck/utility/amd_buffer_addressing.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "data_type.hpp"
@@ -581,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
                                     tmp.template AsType<half2_t>()[i]);
         });
     }
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
     else if constexpr(is_same<T, bhalf_t>::value)
     {
         vector_type<bhalf_t, N> tmp{src_thread_data};
@@ -1021,15 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
     constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
     static_assert(bytes_per_thread == dword_bytes);

+#ifndef CK_CODE_GEN_RTC
     const uint32_t* global_ptr =
         reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
+#else
+    const uint32_t* global_ptr =
+        reinterpret_cast<uint32_t*>(reinterpret_cast<size_t>(global_base_ptr));
+#endif

     const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
     const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;

 #if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
     T* lds_ptr = lds_base_ptr + lds_offset;
+#ifndef CK_CODE_GEN_RTC
     auto const lds_ptr_sgpr =
         __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
+#else
+    auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<size_t>(lds_ptr)));
+#endif
     asm volatile("s_mov_b32 m0, %0; \n\t"
                  "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
                  "v"(global_offset_bytes),
@@ -1038,8 +1047,13 @@
 #else
     // LDS pointer must be attributed with the LDS address space.
     __attribute__((address_space(3))) uint32_t* lds_ptr =
+#ifndef CK_CODE_GEN_RTC
         reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
             reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
+#else
+        reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
+            reinterpret_cast<size_t>(lds_base_ptr + lds_offset));
+#endif

     llvm_amdgcn_raw_buffer_load_lds(
         src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
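A detail in these hunks that is easy to miss: invalid lanes are not branched around. Their byte offset is set to 0x80000000, far beyond the buffer resource's extent, so the hardware's buffer range check neutralizes the access (out-of-range buffer reads return zero) without any divergent control flow. A host-side restatement of the offset selection (illustrative only):

#include <cstdint>

// Offset selection as in amd_direct_load_global_to_lds: instead of branching,
// invalid lanes get an offset past the buffer extent, and the range check
// makes the access harmless.
constexpr uint32_t kInvalidOffset = 0x80000000u;

constexpr uint32_t global_offset_bytes(bool is_valid, uint32_t elem_offset, uint32_t elem_size)
{
    return is_valid ? elem_offset * elem_size : kInvalidOffset;
}

static_assert(global_offset_bytes(true, 10, 4) == 40, "valid lane computes a real offset");
static_assert(global_offset_bytes(false, 10, 4) == kInvalidOffset, "invalid lane goes out of range");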
include/ck/utility/amd_ck_fp8.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck/ck.hpp"
 #include "ck/utility/enable_if.hpp"
 #include "ck/utility/random_gen.hpp"
 #include "ck/utility/type.hpp"

@@ -18,39 +20,25 @@
 #define CK_USE_OCP_FP8 0
 #endif

-namespace {
-// https://en.cppreference.com/w/cpp/types/conditional
-template <bool B, class T, class F>
-struct conditional
-{
-    using type = T;
-};
-
-template <class T, class F>
-struct conditional<false, T, F>
-{
-    using type = F;
-};
-} // namespace
-
-namespace ck {
-
-using f8_fnuz_t  = _BitInt(8);
-using bf8_fnuz_t = unsigned _BitInt(8);
-
 #if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
-    defined(__gfx1201__)) &&                                                                      \
+    defined(__gfx1201__) || defined(__gfx950__)) &&                                              \
     __HIP_DEVICE_COMPILE__
 #define CK_FP8_CVT_FAST_PATH 1
 #else
 #define CK_FP8_CVT_FAST_PATH 0
 #endif

-#if(defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__
+#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
 #define CK_OCP_FP8_CVT_FAST_PATH 1
 #else
 #define CK_OCP_FP8_CVT_FAST_PATH 0
 #endif

+namespace ck {
+
+using f8_fnuz_t  = _BitInt(8);
+using bf8_fnuz_t = unsigned _BitInt(8);
+
 typedef unsigned char fp8_storage_t;

 /**
@@ -205,10 +193,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
         }
     }

-    typename conditional<
+    typename std::conditional<
         sizeof(T) == 2,
         unsigned short int,
-        typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type retval;
+        typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type
+        retval;

     if constexpr(we == 5 && is_half && !is_fnuz)
     {
@@ -301,7 +290,6 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
         return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
     }
 }
 #endif

 } // namespace fp8_impl
@@ -376,7 +364,7 @@ struct bf8_ocp_t
     __host__ explicit operator float() const
 #endif
     {
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
         return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
 #else
         return fp8_impl::cast_from_f8<float, wm, we, false>(
@@ -390,7 +378,7 @@ struct bf8_ocp_t
     __host__ explicit operator _Float16() const
 #endif
     {
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
         return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
 #else
         return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
@@ -424,9 +412,9 @@ __host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a)
 }

 template <typename T,
-          std::enable_if_t<std::is_same_v<T, bf8_ocp_t> || std::is_same_v<T, f8_ocp_t> ||
-                               std::is_same_v<T, bf8_fnuz_t> || std::is_same_v<T, f8_fnuz_t>,
-                           bool> = true>
+          ck::enable_if_t<is_same_v<T, bf8_ocp_t> || is_same_v<T, f8_ocp_t> ||
+                              is_same_v<T, bf8_fnuz_t> || is_same_v<T, f8_fnuz_t>,
+                          bool> = true>
 __host__ __device__ static inline constexpr bool fp8_is_inf(T)
 {
     return false;
@@ -551,10 +539,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
     constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);

-    using T_bitwise = typename conditional<
+    using T_bitwise = typename std::conditional<
         sizeof(T) == 2,
         unsigned short int,
-        typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
+        typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;

     T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
     unsigned long long x{x_bitwise};
@@ -823,7 +811,11 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
     if constexpr(stochastic_rounding)
     {
         constexpr int seed = 1254739;
-        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
+#ifndef CK_CODE_GEN_RTC
+        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
+#else
+        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
+#endif
     }
     return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
         f, rng);
@@ -839,7 +831,11 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
     if constexpr(stochastic_rounding)
     {
         constexpr int seed = 1254739;
+#ifndef CK_CODE_GEN_RTC
         rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
+#else
+        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
+#endif
     }

     if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
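The retval/T_bitwise declarations above pick an unsigned integer of the same width as T through nested std::conditional, so the floating-point payload can be manipulated bit by bit. A standalone illustration of the selection:

#include <type_traits>

// Width-matched unsigned type, as the T_bitwise alias above computes.
template <typename T>
using bitwise_t = typename std::conditional<
    sizeof(T) == 2,
    unsigned short int,
    typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;

static_assert(std::is_same<bitwise_t<short>, unsigned short int>::value, "2-byte payloads");
static_assert(std::is_same<bitwise_t<float>, unsigned int>::value, "4-byte payloads");
static_assert(std::is_same<bitwise_t<double>, unsigned long long>::value, "8-byte payloads");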
include/ck/utility/amd_wave_read_first_lane.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -7,10 +7,12 @@
 #include "ck/utility/functional2.hpp"
 #include "ck/utility/math.hpp"

+#ifndef CK_CODE_GEN_RTC
 #include <array>
 #include <cstddef>
 #include <cstdint>
 #include <type_traits>
+#endif

 namespace ck {
 namespace detail {
@@ -37,7 +39,7 @@ struct get_carrier<3>
 {
     using value_type = uint32_t;

-    std::array<std::byte, 3> bytes;
+    Array<ck::byte, 3> bytes;

     static_assert(sizeof(bytes) <= sizeof(value_type));

     // replacement of host std::copy_n()
@@ -61,22 +63,22 @@ struct get_carrier<3>
     // method to trigger template substitution failure
     __device__ carrier(const carrier& other) noexcept
     {
-        copy_n(other.bytes.begin(), bytes.size(), bytes.begin());
+        copy_n(other.bytes.begin(), bytes.Size(), bytes.begin());
     }

     public:
     __device__ carrier& operator=(value_type value) noexcept
     {
-        copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin());
+        copy_n(reinterpret_cast<const ck::byte*>(&value), bytes.Size(), bytes.begin());

         return *this;
     }

     __device__ operator value_type() const noexcept
     {
-        std::byte result[sizeof(value_type)];
+        ck::byte result[sizeof(value_type)];

-        copy_n(bytes.begin(), bytes.size(), result);
+        copy_n(bytes.begin(), bytes.Size(), result);

         return *reinterpret_cast<const value_type*>(result);
     }
...
@@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
{
constexpr
unsigned
object_size
=
sizeof
(
int64_t
);
constexpr
unsigned
second_part_offset
=
object_size
/
2
;
auto
*
const
from_obj
=
reinterpret_cast
<
const
std
::
byte
*>
(
&
value
);
alignas
(
int64_t
)
std
::
byte
to_obj
[
object_size
];
auto
*
const
from_obj
=
reinterpret_cast
<
const
ck
::
byte
*>
(
&
value
);
alignas
(
int64_t
)
ck
::
byte
to_obj
[
object_size
];
using
Sgpr
=
uint32_t
;
...
...
@@ -122,17 +124,16 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
     return *reinterpret_cast<int64_t*>(to_obj);
 }

-template <typename Object,
-          typename = std::enable_if_t<std::is_class_v<Object> &&
-                                      std::is_trivially_copyable_v<Object>>>
+template <typename Object,
+          typename = ck::enable_if_t<ck::is_class_v<Object> &&
+                                     ck::is_trivially_copyable_v<Object>>>
 __device__ auto amd_wave_read_first_lane(const Object& obj)
 {
     using Size = unsigned;

     constexpr Size SgprSize   = 4;
     constexpr Size ObjectSize = sizeof(Object);

-    auto* const from_obj = reinterpret_cast<const std::byte*>(&obj);
-    alignas(Object) std::byte to_obj[ObjectSize];
+    auto* const from_obj = reinterpret_cast<const ck::byte*>(&obj);
+    alignas(Object) ck::byte to_obj[ObjectSize];

     constexpr Size RemainedSize             = ObjectSize % SgprSize;
     constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
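The object overload broadcasts an arbitrary trivially copyable object through scalar registers by splitting it into 4-byte SGPR chunks plus a remainder. A host-side sketch of the same chunking arithmetic, with readfirstlane replaced by an identity stand-in:

#include <cstdint>
#include <cstring>

// Identity stand-in for __builtin_amdgcn_readfirstlane in this host sketch.
static uint32_t read_first_lane(uint32_t v) { return v; }

// Broadcast a trivially copyable object in 4-byte chunks, then the remainder.
template <typename Object>
Object broadcast_object(const Object& obj)
{
    constexpr unsigned SgprSize   = 4;
    constexpr unsigned ObjectSize = sizeof(Object);
    constexpr unsigned Remainder  = ObjectSize % SgprSize;
    constexpr unsigned Boundary   = ObjectSize - Remainder;

    unsigned char to_obj[ObjectSize];
    const unsigned char* from_obj = reinterpret_cast<const unsigned char*>(&obj);

    for(unsigned offset = 0; offset < Boundary; offset += SgprSize)
    {
        uint32_t chunk;
        std::memcpy(&chunk, from_obj + offset, SgprSize);
        chunk = read_first_lane(chunk);
        std::memcpy(to_obj + offset, &chunk, SgprSize);
    }
    if(Remainder != 0) // trailing 1-3 bytes travel in one more SGPR
    {
        uint32_t chunk = 0;
        std::memcpy(&chunk, from_obj + Boundary, Remainder);
        chunk = read_first_lane(chunk);
        std::memcpy(to_obj + Boundary, &chunk, Remainder);
    }

    Object result;
    std::memcpy(&result, to_obj, ObjectSize);
    return result;
}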
include/ck/utility/amd_xdlops.hpp
@@ -5,7 +5,7 @@
 namespace ck {

 // Define the common macro for MI300 models
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
 #define __gfx94__
 #endif
@@ -134,6 +134,46 @@ struct intrin_mfma_f32_32x32x4f16<32, 64>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x16f16;
+
+template <>
+struct intrin_mfma_f32_32x32x16f16<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x32f16;
+
+template <>
+struct intrin_mfma_f32_16x16x32f16<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x8f16;
@@ -204,6 +244,46 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
 };

 // bfp16
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x16bf16;
+
+template <>
+struct intrin_mfma_f32_32x32x16bf16<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x32bf16;
+
+template <>
+struct intrin_mfma_f32_16x16x32bf16<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x8bf16_1k;
@@ -298,6 +378,46 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_32x32x32i8;
+
+template <>
+struct intrin_mfma_i32_32x32x32i8<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
+            reg_a, reg_b, reg_c.template AsType<int32x16_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_16x16x64i8;
+
+template <>
+struct intrin_mfma_i32_16x16x64i8<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
+            reg_a, reg_b, reg_c.template AsType<int32x4_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_i32_32x32x16i8;
@@ -356,6 +476,149 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x64f8f6f4;
+
+/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6,
+/// and f4 data types.
+///
+/// @note Calls scaled version of the instruction as the original instruction is not supported in
+/// the backend. That is the intended use. There is a backend optimization to select the unscaled
+/// operation if the scale is 0.
+template <>
+struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float16_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float16_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0,
+                0,
+                0,
+                0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_scale_f32_32x32x64f8f6f4;
+
+template <>
+struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x32_t& reg_a,
+                               const int32_t scale_a,
+                               const f8x32_t& reg_b,
+                               const int32_t scale_b,
+                               FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
+        reg_c.template AsType<float16_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float16_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0, // { OPSEL_HI[0], OPSEL[0] }?
+                scale_a,
+                0, // { OPSEL_HI[1], OPSEL[1] }?
+                scale_b);
+#else
+        ignore = reg_a;
+        ignore = scale_a;
+        ignore = reg_b;
+        ignore = scale_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_scale_f32_16x16x128f8f6f4;
+
+template <>
+struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x32_t& reg_a,
+                               const int32_t scale_a,
+                               const f8x32_t& reg_b,
+                               const int32_t scale_b,
+                               FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float4_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0, // { OPSEL_HI[0], OPSEL[0] }?
+                scale_a,
+                0, // { OPSEL_HI[1], OPSEL[1] }?
+                scale_b);
+#else
+        ignore = reg_a;
+        ignore = scale_a;
+        ignore = reg_b;
+        ignore = scale_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x128f8f6f4;
+
+/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4
+/// data types.
+///
+/// @note Calls scaled version of the instruction as the original instruction is not supported in
+/// the backend. That is the intended use. There is a backend optimization to select the unscaled
+/// operation if the scale is 0.
+template <>
+struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float4_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0,
+                0,
+                0,
+                0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x16f8f8;
include/ck/utility/array.hpp
View file @
dbb7002d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP
...
...
@@ -38,6 +38,8 @@ struct Array
}
__host__
__device__
constexpr
const
TData
*
begin
()
const
{
return
&
mData
[
0
];
}
__host__
__device__
constexpr
const
TData
*
end
()
const
{
return
&
mData
[
NSize
];
}
__host__
__device__
constexpr
TData
*
begin
()
{
return
&
mData
[
0
];
}
__host__
__device__
constexpr
TData
*
end
()
{
return
&
mData
[
NSize
];
}
};
// empty Array
...
...
@@ -54,7 +56,7 @@ template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{
    using data_type = remove_cvref_t<X>;
-   return Array<data_type, sizeof...(Xs) + 1>{std::forward<X>(x), std::forward<Xs>(xs)...};
+   return Array<data_type, sizeof...(Xs) + 1>{ck::forward<X>(x), ck::forward<Xs>(xs)...};
}
// make empty array
...
...
include/ck/utility/container_helper.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_CONTAINER_HELPER_HPP
#define CK_CONTAINER_HELPER_HPP
...
...
@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
{
    return unpack2(
-       [&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay);
+       [&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
}
template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
    return unpack2(
-       [&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
+       [&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
}
template <typename Container>
...
...
include/ck/utility/data_type.hpp
...
...
@@ -4,13 +4,316 @@
#pragma once

#include "ck/utility/amd_ck_fp8.hpp"
#include "ck/utility/e8m0.hpp"
#include "ck/utility/statically_indexed_array.hpp"

#ifdef CK_CODE_GEN_RTC
using int8_t   = signed char;
using uint8_t  = unsigned char;
using int16_t  = signed short;
using uint16_t = unsigned short;
using float_t  = float;
#endif

namespace ck
{

#ifdef CK_CODE_GEN_RTC
using byte = unsigned char;
#else
using std::byte;
#endif

using bhalf_t = ushort;
using half_t  = _Float16;
using int4_t  = _BitInt(4);
using f4_t    = unsigned _BitInt(4);
using f6_t    = _BitInt(6);          // e2m3 format
using bf6_t   = unsigned _BitInt(6); // e3m2 format
struct f4x2_pk_t
{
    using type = uint8_t;
    type data;
    f4x2_pk_t() : data{type{}} {}
    f4x2_pk_t(type init) : data{init} {}

    template <index_t I>
    __host__ __device__ inline type unpack(Number<I>) const
    {
        static_assert(I < 2, "Index is out of range.");
        if constexpr(I == 0)
            return data & 0b00001111;
        else
            return (data >> 4);
    }

    __host__ __device__ inline type pack(const type x0, const type x1)
    {
        return (x1 << 4) | (x0 & 0b00001111);
    }
};
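The nibble layout is easy to check in isolation. A minimal host-only sketch of the same pack/unpack arithmetic on a plain uint8_t (no CK types involved):

#include <cassert>
#include <cstdint>

int main()
{
    // pack: x1 goes to the high nibble, x0 to the low nibble
    const uint8_t x0 = 0b0101, x1 = 0b1010;
    const uint8_t packed = static_cast<uint8_t>((x1 << 4) | (x0 & 0b00001111));

    // unpack: index 0 reads the low nibble, index 1 the high nibble
    assert((packed & 0b00001111) == x0);
    assert((packed >> 4) == x1);
    return 0;
}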
struct f6x16_pk_t
{
    // store 16 elements of f6_t in an array of 3 uint32_t
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 3>;
    type data;
    typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));

    f6x16_pk_t() : data{type{}} {}
    f6x16_pk_t(type init) : data{init} {}

    template <index_t I>
    __host__ __device__ inline f6_t unpack(Number<I>)
    {
        static_assert(I < 16, "Index out of range for 16 f6_t elements.");
        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 3;
        constexpr int bit_pos           = I * num_bits_elem;
        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;

        uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;

        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
                    << (num_bits_elem - overhang);
        }

        return static_cast<f6_t>(bits & 0x3F);
    }

    __host__ __device__ inline type pack(const test_vec_t& x)
    {
        type packed{};

        // for each of the 16 f6_t values, place its 6 bits in the correct position
        ck::static_for<0, 16, 1>{}([&](auto i) {
            uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;

            constexpr int num_bits_elem     = 6;
            constexpr int num_bits_vec_elem = 32;
            constexpr int vector_size       = 3;
            constexpr int bit_pos           = i * num_bits_elem;
            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;

            uint32_t old_value = packed.At(Number<arr_index>{});

            // insert bits into the current 32-bit block
            old_value |= (bits << bit_offset);
            packed.At(Number<arr_index>{}) = old_value;

            // if it crosses into the next block, shift the remainder
            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
            {
                uint32_t next_value = packed.At(Number<arr_index + 1>{});
                next_value |= (bits >> (num_bits_elem - overhang));
                packed.At(Number<arr_index + 1>{}) = next_value;
            }
        });

        return packed;
    }
};
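The overhang handling is the subtle part: element 5 starts at bit 30, so its low 2 bits land in word 0 and its high 4 bits in word 1. A standalone sketch of the same arithmetic on raw uint32_t words (runtime loop instead of static_for, CK types omitted; the test pattern is arbitrary):

#include <cassert>
#include <cstdint>

int main()
{
    uint32_t words[3] = {0, 0, 0}; // 16 x 6 bits = 96 bits = 3 x uint32_t

    // pack: element i carries the 6-bit pattern (i * 5) & 0x3F
    for(int i = 0; i < 16; ++i)
    {
        const uint32_t bits = static_cast<uint32_t>(i * 5) & 0x3F;
        const int bit_pos = i * 6, idx = bit_pos / 32, off = bit_pos % 32;
        const int overhang = off + 6 - 32;
        words[idx] |= bits << off;
        if(overhang > 0 && idx + 1 < 3)
            words[idx + 1] |= bits >> (6 - overhang);
    }

    // unpack and verify the round trip
    for(int i = 0; i < 16; ++i)
    {
        const int bit_pos = i * 6, idx = bit_pos / 32, off = bit_pos % 32;
        const int overhang = off + 6 - 32;
        uint32_t bits = words[idx] >> off;
        if(overhang > 0 && idx + 1 < 3)
            bits |= (words[idx + 1] & ((1u << overhang) - 1)) << (6 - overhang);
        assert((bits & 0x3F) == (static_cast<uint32_t>(i * 5) & 0x3F));
    }
    return 0;
}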
struct f6x32_pk_t
{
    // store 32 elements of f6_t in an array of 6 uint32_t
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 6>;
    type data;
    typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));

    f6x32_pk_t() : data{type{}} {}
    f6x32_pk_t(type init) : data{init} {}

    template <index_t I>
    __host__ __device__ inline f6_t unpack(Number<I>)
    {
        static_assert(I < 32, "Index out of range for 32 f6_t elements.");
        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 6;
        constexpr int bit_pos           = I * num_bits_elem;
        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;

        uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;

        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
                    << (num_bits_elem - overhang);
        }

        return static_cast<f6_t>(bits & 0x3F);
    }

    __host__ __device__ inline type pack(const test_vec_t& x)
    {
        type packed{};

        // for each of the 32 f6_t values, place its 6 bits in the correct position
        ck::static_for<0, 32, 1>{}([&](auto i) {
            uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;

            constexpr int num_bits_elem     = 6;
            constexpr int num_bits_vec_elem = 32;
            constexpr int vector_size       = 6;
            constexpr int bit_pos           = i * num_bits_elem;
            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;

            uint32_t old_value = packed.At(Number<arr_index>{});

            // insert bits into the current 32-bit block
            old_value |= (bits << bit_offset);
            packed.At(Number<arr_index>{}) = old_value;

            // if it crosses into the next block, shift the remainder
            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
            {
                uint32_t next_value = packed.At(Number<arr_index + 1>{});
                next_value |= (bits >> (num_bits_elem - overhang));
                packed.At(Number<arr_index + 1>{}) = next_value;
            }
        });

        return packed;
    }
};
struct bf6x16_pk_t
{
    // store 16 elements of bf6_t in an array of 3 uint32_t
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 3>;
    type data;
    typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));

    bf6x16_pk_t() : data{type{}} {}
    bf6x16_pk_t(type init) : data{init} {}

    template <index_t I>
    __host__ __device__ inline bf6_t unpack(Number<I>)
    {
        static_assert(I < 16, "Index out of range for 16 bf6_t elements.");
        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 3;
        constexpr int bit_pos           = I * num_bits_elem;
        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;

        uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;

        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
                    << (num_bits_elem - overhang);
        }

        return static_cast<bf6_t>(bits & 0x3F);
    }

    __host__ __device__ inline type pack(const test_vec_t& x)
    {
        type packed{};

        // for each of the 16 bf6_t values, place its 6 bits in the correct position
        ck::static_for<0, 16, 1>{}([&](auto i) {
            uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;

            constexpr int num_bits_elem     = 6;
            constexpr int num_bits_vec_elem = 32;
            constexpr int vector_size       = 3;
            constexpr int bit_pos           = i * num_bits_elem;
            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;

            uint32_t old_value = packed.At(Number<arr_index>{});

            // insert bits into the current 32-bit block
            old_value |= (bits << bit_offset);
            packed.At(Number<arr_index>{}) = old_value;

            // if it crosses into the next block, shift the remainder
            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
            {
                uint32_t next_value = packed.At(Number<arr_index + 1>{});
                next_value |= (bits >> (num_bits_elem - overhang));
                packed.At(Number<arr_index + 1>{}) = next_value;
            }
        });

        return packed;
    }
};
struct bf6x32_pk_t
{
    // store 32 elements of bf6_t in an array of 6 uint32_t
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 6>;
    type data;
    typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));

    bf6x32_pk_t() : data{type{}} {}
    bf6x32_pk_t(type init) : data{init} {}

    template <index_t I>
    __host__ __device__ inline bf6_t unpack(Number<I>)
    {
        static_assert(I < 32, "Index out of range for 32 bf6_t elements.");
        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 6;
        constexpr int bit_pos           = I * num_bits_elem;
        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;

        uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;

        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
                    << (num_bits_elem - overhang);
        }

        return static_cast<bf6_t>(bits & 0x3F);
    }

    __host__ __device__ inline type pack(const test_vec_t& x)
    {
        type packed{};

        // for each of the 32 bf6_t values, place its 6 bits in the correct position
        ck::static_for<0, 32, 1>{}([&](auto i) {
            uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;

            constexpr int num_bits_elem     = 6;
            constexpr int num_bits_vec_elem = 32;
            constexpr int vector_size       = 6;
            constexpr int bit_pos           = i * num_bits_elem;
            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;

            uint32_t old_value = packed.At(Number<arr_index>{});

            // insert bits into the current 32-bit block
            old_value |= (bits << bit_offset);
            packed.At(Number<arr_index>{}) = old_value;

            // if it crosses into the next block, shift the remainder
            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
            {
                uint32_t next_value = packed.At(Number<arr_index + 1>{});
                next_value |= (bits >> (num_bits_elem - overhang));
                packed.At(Number<arr_index + 1>{}) = next_value;
            }
        });

        return packed;
    }
};
// custom data type - pack int4 data
struct pk_i4_t
...
...
@@ -28,14 +331,15 @@ inline constexpr auto next_pow2(uint32_t x)
}

// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
-// native types: bool
+// native types: bool, f4_t, f6_t, bf6_t
template <typename T>
inline constexpr bool is_native_type()
{
    return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
           is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
           is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
-          is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
+          is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value || is_same<T, f4_t>::value ||
+          is_same<T, f6_t>::value || is_same<T, bf6_t>::value;
}
// vector_type
...
...
@@ -217,7 +521,7 @@ struct scalar_type<bool>
};

template <typename T>
-struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 1, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    using type = d1_t;
...
...
@@ -253,7 +557,7 @@ struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
__device__ int static err = 0;

template <typename T>
-struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 2, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
...
...
@@ -313,7 +617,7 @@ struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 3, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 3, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
...
...
@@ -383,7 +687,7 @@ struct vector_type<T, 3, typename std::enable_if_t<is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 4, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
...
...
@@ -453,7 +757,7 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 5, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 5, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d4_t __attribute__((ext_vector_type(4)));
...
...
@@ -523,7 +827,7 @@ struct vector_type<T, 5, typename std::enable_if_t<is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 7, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 7, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
...
...
@@ -605,7 +909,7 @@ struct vector_type<T, 7, typename std::enable_if_t<is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 8, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
...
...
@@ -687,7 +991,7 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 13, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 13, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d4_t __attribute__((ext_vector_type(4)));
...
...
@@ -769,7 +1073,7 @@ struct vector_type<T, 13, typename std::enable_if_t<is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 16, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
...
...
@@ -863,7 +1167,7 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 32, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
...
...
@@ -967,7 +1271,7 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 64, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
...
...
@@ -1083,7 +1387,7 @@ struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 128, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
...
...
@@ -1209,7 +1513,7 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 256, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
...
...
@@ -1358,12 +1662,37 @@ struct nnvb_data_t_selector<f8_ocp_t>
{
    using type = f8_ocp_t::data_type;
};

template <>
struct nnvb_data_t_selector<bf8_ocp_t>
{
    using type = bf8_ocp_t::data_type;
};

template <>
struct nnvb_data_t_selector<f6x16_pk_t>
{
    using type = f6x16_pk_t::type;
};

template <>
struct nnvb_data_t_selector<f6x32_pk_t>
{
    using type = f6x32_pk_t::type;
};

template <>
struct nnvb_data_t_selector<bf6x16_pk_t>
{
    using type = bf6x16_pk_t::type;
};

template <>
struct nnvb_data_t_selector<bf6x32_pk_t>
{
    using type = bf6x32_pk_t::type;
};

template <>
struct nnvb_data_t_selector<pk_i4_t>
{
...
...
@@ -1374,7 +1703,7 @@ template <typename T, index_t N>
struct non_native_vector_base<
    T,
    N,
-   std::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>
+   ck::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>
{
    using data_t = typename nnvb_data_t_selector<T>::type; // select data_t based on the size of T

    static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
...
...
@@ -1470,6 +1799,63 @@ struct non_native_vector_base<
    }
};

// implementation for f6x16 and f6x32
template <typename T, index_t N>
struct non_native_vector_base<T, N, std::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
{
    using data_t = typename nnvb_data_t_selector<T>::type; // select data_t based on declared base type
    using element_t = typename T::element_type;            // select element_t based on declared element type

    static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");

    static constexpr size_t size_factor = sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6

    using data_v = element_t __attribute__((ext_vector_type(N * size_factor)));
    using type   = non_native_vector_base<T, N>;

    union alignas(next_pow2(N * sizeof(T)))
    {
        data_v dN; // storage vector;
        StaticallyIndexedArray<data_t, N> dxN;
        StaticallyIndexedArray<T, N> dTxN;
        StaticallyIndexedArray<data_v, 1> dNx1;
    } data_;

    __host__ __device__ constexpr non_native_vector_base(data_t a)
        : data_{data_v(a.At(Number<0>{}))}
    {
    }
    __host__ __device__ constexpr non_native_vector_base(T f)
        : non_native_vector_base(bit_cast<data_t>(f))
    {
    }
    __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
    __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}

    __host__ __device__ constexpr operator data_v() const { return data_.dN; }
    __host__ __device__ constexpr operator data_t() const
    {
        if constexpr(N == 1)
        {
            return data_.dxN[Number<0>{}];
        }
        else
        {
            return data_.dxN; // XXX this should cause an error
        }
    }
    __host__ __device__ constexpr operator T() const
    {
        if constexpr(N == 1)
        {
            return data_.dTxN[Number<0>{}];
        }
        else
        {
            return data_.dTxN; // XXX this should cause an error
        }
    }
};

template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>;
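A quick sanity check of the storage math this specialization relies on: f6x16_pk_t packs 16 x 6 = 96 bits into three uint32_t (12 bytes, size_factor 3), f6x32_pk_t packs 192 bits into six (24 bytes, size_factor 6), and next_pow2 rounds the union alignment up to 16 and 32 bytes respectively. A host-side sketch (the standalone next_pow2 below is assumed equivalent to CK's):

#include <cstdint>

// standalone stand-in for ck::next_pow2
constexpr uint32_t next_pow2(uint32_t x)
{
    uint32_t p = 1;
    while(p < x)
        p <<= 1;
    return p;
}

static_assert(16 * 6 / 8 == 12, "f6x16 payload is 12 bytes -> 3 uint32_t");
static_assert(32 * 6 / 8 == 24, "f6x32 payload is 24 bytes -> 6 uint32_t");
static_assert(next_pow2(12) == 16, "f6x16 union is aligned to 16 bytes");
static_assert(next_pow2(24) == 32, "f6x32 union is aligned to 32 bytes");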
...
...
@@ -1499,7 +1885,7 @@ struct scalar_type<non_native_vector_base<pk_i4_t, N>>
// non-native vector_type implementation
template <typename T>
-struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 1, typename ck::enable_if_t<!is_native_type<T>()>>
{
    using d1_t     = T;
    using d1_nnv_t = non_native_vector_base<T, 1>;
...
...
@@ -1550,7 +1936,7 @@ struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 2, typename ck::enable_if_t<!is_native_type<T>()>>
{
    using d1_t     = T;
    using d1_nnv_t = non_native_vector_base<T, 1>;
...
...
@@ -1613,7 +1999,7 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 4, typename ck::enable_if_t<!is_native_type<T>()>>
{
    using d1_t     = T;
    using d1_nnv_t = non_native_vector_base<T, 1>;
...
...
@@ -1686,7 +2072,7 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 8, typename ck::enable_if_t<!is_native_type<T>()>>
{
    using d1_t     = T;
    using d1_nnv_t = non_native_vector_base<T, 1>;
...
...
@@ -1771,7 +2157,7 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 16, typename ck::enable_if_t<!is_native_type<T>()>>
{
    using d1_t     = T;
    using d1_nnv_t = non_native_vector_base<T, 1>;
...
...
@@ -1866,7 +2252,7 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 32, typename ck::enable_if_t<!is_native_type<T>()>>
{
    using d1_t = T;
    using d2_t = non_native_vector_base<T, 2>;
...
...
@@ -1970,7 +2356,7 @@ struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
};

template <typename T>
-struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 64, typename ck::enable_if_t<!is_native_type<T>()>>
{
    using d1_t = T;
    using d2_t = non_native_vector_base<T, 2>;
...
...
@@ -2205,25 +2591,251 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type;

// f4
using f4x2_t  = typename vector_type<f4x2_pk_t, 1>::type;
using f4x4_t  = typename vector_type<f4x2_pk_t, 2>::type;
using f4x8_t  = typename vector_type<f4x2_pk_t, 4>::type;
using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type;
using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type;
using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type;

// f6
using f6x16_t = typename vector_type<f6x16_pk_t, 1>::type;
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;

// bf6
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;

// pack int4
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
#ifdef CK_CODE_GEN_RTC
template <typename T>
struct NumericLimits;

template <>
struct NumericLimits<int32_t>
{
    __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
    __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
    __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
    __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<int16_t>
{
    __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }
    __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
    __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
    __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<int8_t>
{
    __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
    __host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
    __host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
    __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint32_t>
{
    __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
    __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
    __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
    __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint16_t>
{
    __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
    __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
    __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
    __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<float>
{
    static constexpr unsigned int binary_min    = 0x00800000;
    static constexpr unsigned int binary_max    = 0x7F7FFFFF;
    static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
    static constexpr unsigned int binary_qnan   = 0xFFC00001;
    static constexpr unsigned int binary_inf    = 0x7F800000;

    __host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
    __host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
    __host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
    __host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
    __host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
};
template <>
struct NumericLimits<half_t>
{
    static constexpr unsigned short binary_min    = 0x0400;
    static constexpr unsigned short binary_max    = 0x7BFF;
    static constexpr unsigned short binary_lowest = 0xFBFF;
    static constexpr unsigned short binary_qnan   = 0x7FFF;

    __host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
    __host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
    __host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
    __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<int4_t>
{
    __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }
    __host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
    __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); }
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<f8_fnuz_t>
{
    // negative zero nan mode with exp bias = 8
    static constexpr uint8_t binary_min    = 0x08; // 0b00001000
    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000

    // ieee mode with exp bias = 7
    // static constexpr uint8_t binary_min    = 0x08; // 0b00001000
    // static constexpr uint8_t binary_max    = 0x77; // 0b01110111
    // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0

    __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }
    __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }
    __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }
    __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
};
template <>
struct NumericLimits<bf8_fnuz_t>
{
    // negative zero nan mode with exp bias = 16
    static constexpr uint8_t binary_min    = 0x04; // 0b00000100
    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000

    // ieee mode with exp bias = 15
    // static constexpr uint8_t binary_min    = 0x04; // 0b00000100
    // static constexpr uint8_t binary_max    = 0x7B; // 0b01111011
    // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0

    __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }
    __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }
    __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); }
    __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); }
};
template <>
struct NumericLimits<f8_ocp_t>
{
    static constexpr uint8_t binary_min    = 0x08; // 0b00001000 = 2^-6
    static constexpr uint8_t binary_max    = 0x7E; // 0b01111110 = 448
    static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
    static constexpr uint8_t binary_qnan   = 0x7F; // 0b01111111

    __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast<f8_ocp_t>(binary_min); }
    __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast<f8_ocp_t>(binary_max); }
    __host__ __device__ static constexpr f8_ocp_t Lowest() { return bit_cast<f8_ocp_t>(binary_lowest); }
    __host__ __device__ static constexpr f8_ocp_t QuietNaN() { return bit_cast<f8_ocp_t>(binary_qnan); }
};
template <>
struct NumericLimits<bf8_ocp_t>
{
    static constexpr uint8_t binary_min    = 0x04; // 0b00000100 = 2^-14
    static constexpr uint8_t binary_max    = 0x7B; // 0b01111011 = 57344
    static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344
    static constexpr uint8_t binary_qnan   = 0x7D; // 0b01111101

    __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast<bf8_ocp_t>(binary_min); }
    __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast<bf8_ocp_t>(binary_max); }
    __host__ __device__ static constexpr bf8_ocp_t Lowest() { return bit_cast<bf8_ocp_t>(binary_lowest); }
    __host__ __device__ static constexpr bf8_ocp_t QuietNaN() { return bit_cast<bf8_ocp_t>(binary_qnan); }
};
#else
template <typename T>
struct NumericLimits
{
    __host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
    __host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
    __host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
    __host__ __device__ static constexpr T QuietNaN() { return std::numeric_limits<T>::quiet_NaN(); }
    __host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); }
};
...
...
@@ -2347,6 +2959,119 @@ struct NumericLimits<bf8_ocp_t>
        return bit_cast<bf8_ocp_t>(binary_qnan);
    }
};
#endif
template <>
struct NumericLimits<f4_t>
{
    static constexpr uint8_t binary_min_normal    = 0x2; // 0b0010
    static constexpr uint8_t binary_max_normal    = 0x7; // 0b0111
    static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111
    static constexpr uint8_t binary_min_subnorm   = 0x1; // 0b0001
    static constexpr uint8_t binary_max_subnorm   = 0x1; // 0b0001

    static constexpr float data_max_normal_number    = 6;
    static constexpr float data_min_subnormal_number = 0.5;

    __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); }
    __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); }
    __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); }
    __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); }
    __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); }
    __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
    __host__ __device__ static constexpr float DataMinSubnorm() { return data_min_subnormal_number; }
};
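These constants follow directly from the e2m1 layout (1 sign, 2 exponent, 1 mantissa bit, bias 1). A host-side decoder sketch (illustrative, not CK code) that reproduces the limits above:

#include <cassert>

// decode a 4-bit e2m1 value (bias = 1) to float
float decode_f4_e2m1(unsigned bits)
{
    const unsigned sign = (bits >> 3) & 0x1;
    const unsigned exp  = (bits >> 1) & 0x3;
    const unsigned mant = bits & 0x1;
    float mag;
    if(exp == 0)
        mag = mant * 0.5f; // subnormal: 0.0 or 0.5
    else
        mag = (1.0f + 0.5f * mant) * (1u << (exp - 1)); // normal: 2^(exp-1) * 1.m
    return sign ? -mag : mag;
}

int main()
{
    assert(decode_f4_e2m1(0x7) == 6.0f);  // Max()
    assert(decode_f4_e2m1(0xF) == -6.0f); // Lowest()
    assert(decode_f4_e2m1(0x2) == 1.0f);  // Min(), the smallest normal
    assert(decode_f4_e2m1(0x1) == 0.5f);  // MinSubnorm() / MaxSubnorm()
    return 0;
}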
template <>
struct NumericLimits<f6_t>
{
    static constexpr uint8_t binary_min_normal    = 0x08; // 0b001000
    static constexpr uint8_t binary_max_normal    = 0x1F; // 0b011111
    static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111
    static constexpr uint8_t binary_min_subnorm   = 0x01; // 0b000001
    static constexpr uint8_t binary_max_subnorm   = 0x07; // 0b000111

    static constexpr float data_max_normal_number    = 7.5;
    static constexpr float data_min_subnormal_number = 0.125;

    __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); }
    __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); }
    __host__ __device__ static constexpr f6_t Lowest() { return f6_t(binary_lowest_normal & 0b111111); }
    __host__ __device__ static constexpr f6_t MinSubnorm() { return f6_t(binary_min_subnorm & 0b111111); }
    __host__ __device__ static constexpr f6_t MaxSubnorm() { return f6_t(binary_max_subnorm & 0b111111); }
    __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
    __host__ __device__ static constexpr float DataMinSubnorm() { return data_min_subnormal_number; }
};
template <>
struct NumericLimits<bf6_t>
{
    static constexpr uint8_t binary_min_normal    = 0x08; // 0b001000
    static constexpr uint8_t binary_max_normal    = 0x1F; // 0b011111
    static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111
    static constexpr uint8_t binary_min_subnorm   = 0x01; // 0b000001
    static constexpr uint8_t binary_max_subnorm   = 0x03; // 0b000011

    static constexpr float data_max_normal_number    = 28;
    static constexpr float data_min_subnormal_number = 0.0625;

    __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); }
    __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); }
    __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); }
    __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); }
    __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); }
    __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
    __host__ __device__ static constexpr float DataMinSubnorm() { return data_min_subnormal_number; }
};
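The f6 (e2m3) and bf6 (e3m2) limits can be checked the same way. A host-side sketch (illustrative, not CK code), parameterized over the exponent/mantissa widths and bias that the NumericUtils entries below also record:

#include <cassert>
#include <cmath>

// decode a 6-bit minifloat with the given exponent/mantissa widths and bias
float decode6(unsigned bits, int exp_bits, int mant_bits, int bias)
{
    const unsigned sign = (bits >> (exp_bits + mant_bits)) & 0x1;
    const unsigned exp  = (bits >> mant_bits) & ((1u << exp_bits) - 1);
    const unsigned mant = bits & ((1u << mant_bits) - 1);
    const float frac = static_cast<float>(mant) / static_cast<float>(1u << mant_bits);
    const float mag  = (exp == 0)
                           ? std::ldexp(frac, 1 - bias)                              // subnormal
                           : std::ldexp(1.0f + frac, static_cast<int>(exp) - bias);  // normal
    return sign ? -mag : mag;
}

int main()
{
    assert(decode6(0b011111, 2, 3, 1) == 7.5f);    // f6  (e2m3) DataMaxNorm
    assert(decode6(0b000001, 2, 3, 1) == 0.125f);  // f6  DataMinSubnorm
    assert(decode6(0b011111, 3, 2, 3) == 28.0f);   // bf6 (e3m2) DataMaxNorm
    assert(decode6(0b000001, 3, 2, 3) == 0.0625f); // bf6 DataMinSubnorm
    return 0;
}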
template <>
struct NumericLimits<e8m0_bexp_t>
{
    static constexpr e8m0_bexp_t binary_min  = 0x00; // 0b00000000
    static constexpr e8m0_bexp_t binary_max  = 0xFE; // 0b11111110
    static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111
    static constexpr e8m0_bexp_t binary_1    = 0x7F; // 0b01111111
    static constexpr e8m0_bexp_t binary_2    = 0x80; // 0b10000000
    static constexpr e8m0_bexp_t binary_3    = 0x82; // 0b10000010
    static constexpr e8m0_bexp_t binary_135  = 0x87; // 0b10000111
    static constexpr e8m0_bexp_t binary_142  = 0x8E; // 0b10001110

    __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); }
    __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); }
    __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); }
    __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); }
    __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); }
    __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); }
    __host__ __device__ static constexpr e8m0_bexp_t Binary_135() { return e8m0_bexp_t(binary_135); }
    __host__ __device__ static constexpr e8m0_bexp_t Binary_142() { return e8m0_bexp_t(binary_142); }
};
template <typename T>
struct NumericUtils
...
...
@@ -2367,6 +3092,7 @@ struct NumericUtils<float>
    static constexpr uint32_t NegInf = 0xFF800000;
    static constexpr uint32_t NaN    = 0x7F800001;
    static constexpr uint32_t Neg0   = 0x80000000;
    static constexpr bool has_inf    = true;
    using bitwise_type = uint32_t;
};
...
...
@@ -2384,9 +3110,19 @@ struct NumericUtils<half_t>
    static constexpr uint32_t NegInf = 0xFC00;
    static constexpr uint32_t NaN    = 0x7C01;
    static constexpr uint32_t Neg0   = 0x8000;
    static constexpr bool has_inf    = true;
    using bitwise_type = uint16_t;
};

template <>
struct NumericUtils<bhalf_t>
{
    static constexpr int exp  = 8;
    static constexpr int mant = 7;
    static constexpr int bias = 128; // negative zero nan mode
    // static constexpr int bias = 127; // ieee mode
};

template <>
struct NumericUtils<f8_fnuz_t>
{
...
...
@@ -2394,6 +3130,7 @@ struct NumericUtils<f8_fnuz_t>
    static constexpr int mant = 3;
    static constexpr int bias = 8; // negative zero nan mode
    // static constexpr int bias = 7; // ieee mode
    static constexpr bool has_inf = false;
};

template <>
...
...
@@ -2403,6 +3140,7 @@ struct NumericUtils<bf8_fnuz_t>
    static constexpr int mant = 2;
    static constexpr int bias = 16; // negative zero nan mode
    // static constexpr int bias = 15; // ieee mode
    static constexpr bool has_inf = false;
};

template <>
struct NumericUtils<f8_ocp_t>
...
...
@@ -2421,11 +3159,109 @@ struct NumericUtils<bf8_ocp_t>
};

template <>
-struct NumericUtils<bhalf_t>
+struct NumericUtils<f4_t>
{
    static constexpr int exp  = 2;
    static constexpr int mant = 1;
    static constexpr int bias = 1;
    static constexpr uint32_t sr_shift = 10;

    static constexpr int unbiased_exp_min = 0;
    static constexpr int unbiased_exp_max = 2;
    static constexpr int biased_exp_min   = 1;
    static constexpr int biased_exp_max   = 3;

    static constexpr uint8_t positive_zero_mask = 0b0000;
    static constexpr uint8_t negative_zero_mask = 0b1000;
    static constexpr uint8_t one_mask           = 0b0010;
    static constexpr uint8_t set_sign_mask      = 0b0111;

    static constexpr uint8_t data_max_positive_normal_mask    = 0b0111;
    static constexpr uint8_t data_max_negative_normal_mask    = 0b1111;
    static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001;
    static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001;

    static constexpr bool has_inf = false;
    using bitwise_type = uint8_t;
};
template <>
struct NumericUtils<f6_t>
{
    static constexpr int exp  = 2;
    static constexpr int mant = 3;
    static constexpr int bias = 1;
    static constexpr uint32_t sr_shift = 12;

    static constexpr int unbiased_exp_min = 0;
    static constexpr int unbiased_exp_max = 2;
    static constexpr int biased_exp_min   = 1;
    static constexpr int biased_exp_max   = 3;

    static constexpr uint8_t positive_zero_mask = 0b000000;
    static constexpr uint8_t negative_zero_mask = 0b100000;
    static constexpr uint8_t set_sign_mask      = 0b011111;

    static constexpr uint8_t data_max_positive_normal_mask    = 0b011111;
    static constexpr uint8_t data_max_negative_normal_mask    = 0b111111;
    static constexpr uint8_t data_max_positive_subnormal_mask = 0b000111;
    static constexpr uint8_t data_max_negative_subnormal_mask = 0b100111;

    static constexpr bool has_inf  = false;
    static constexpr bool has_nan  = false;
    static constexpr bool has_zero = true;
    using bitwise_type = uint8_t;
};
template <>
struct NumericUtils<bf6_t>
{
    static constexpr int exp  = 3;
    static constexpr int mant = 2;
    static constexpr int bias = 3;
    static constexpr uint32_t sr_shift = 11;

    static constexpr int unbiased_exp_min = -2;
    static constexpr int unbiased_exp_max = 4;
    static constexpr int biased_exp_min   = 1;
    static constexpr int biased_exp_max   = 7;

    static constexpr uint8_t positive_zero_mask = 0b000000;
    static constexpr uint8_t negative_zero_mask = 0b100000;
    static constexpr uint8_t set_sign_mask      = 0b011111;

    static constexpr uint8_t data_max_positive_normal_mask    = 0b011111;
    static constexpr uint8_t data_max_negative_normal_mask    = 0b111111;
    static constexpr uint8_t data_max_positive_subnormal_mask = 0b000011;
    static constexpr uint8_t data_max_negative_subnormal_mask = 0b100011;

    static constexpr bool has_inf  = false;
    static constexpr bool has_nan  = false;
    static constexpr bool has_zero = true;
    using bitwise_type = uint8_t;
};
template <>
struct NumericUtils<e8m0_bexp_t>
{
    static constexpr int exp  = 8;
-   static constexpr int mant = 7;
-   static constexpr int bias = 128; // negative zero nan mode
-   // static constexpr int bias = 127; // ieee mode
+   static constexpr int mant = 0;
+   static constexpr int bias = 127;

    static constexpr int unbiased_exp_min = -127;
    static constexpr int unbiased_exp_max = 127;
    static constexpr int biased_exp_min   = 0;
    static constexpr int biased_exp_max   = 254;

    using bitwise_type = uint8_t;
};

} // namespace ck
include/ck/utility/debug.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP

#include "type.hpp"

namespace ck {
namespace debug {
...
...
include/ck/utility/e8m0.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/type.hpp"

namespace ck
{
/**
* @brief Unsigned representation of a conventional biased Float32 exponent.
*
* bias = 127;
*
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
* E8M0_MIN = 0b00000000; => 2^-127
* E8M0_MAX = 0b11111110; => 2^127
* E8M0_NAN = 0b11111111; => NaN
*/
struct e8m0_bexp_t
{
    using type = uint8_t;
    type data;

    constexpr static type bias     = 127;
    constexpr static type nan_mask = 0xFF;

    __host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {}
    __host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {}
    __host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast<type>(init & nan_mask)} {}
    __host__ __device__ explicit constexpr e8m0_bexp_t(float scale)
        : data{static_cast<type>((bit_cast<uint32_t>(scale) & (nan_mask << 23)) >> 23)}
    {
    }

    __host__ __device__ explicit constexpr operator float() const
    {
        if(data == nan_mask || data == 0)
        {
            // 0x00 (2^-127) and 0xFF (NaN) have no normal float32 encoding;
            // build the corresponding subnormal/NaN bit pattern instead
            uint32_t bits = data << 1;
            bits |= 1;
            bits <<= 22;
            return bit_cast<float>(bits);
        }
        else
        {
            uint32_t bits = data << 23;
            return bit_cast<float>(bits);
        }
    }

    __host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const
    {
        // strict IEEE compliance for NaN
        return data == other.data && data != nan_mask;
    }

    __host__ __device__ constexpr bool is_nan() const { return data == nan_mask; }
};
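A host-only round-trip sketch of the two conversions above (memcpy stands in for bit_cast; the special 0x00/0xFF decode branch is not exercised here):

#include <cassert>
#include <cstdint>
#include <cstring>

int main()
{
    // encode: extract bits 30..23 (the biased float32 exponent) of 256.0f
    const float scale = 256.0f; // 2^8 -> biased exponent 127 + 8 = 135
    uint32_t u;
    std::memcpy(&u, &scale, sizeof(u));
    const uint8_t e8m0 = static_cast<uint8_t>((u >> 23) & 0xFF);
    assert(e8m0 == 135); // matches E8M0_135 in the table above

    // decode: place the byte back into the float32 exponent field
    const uint32_t bits = static_cast<uint32_t>(e8m0) << 23;
    float back;
    std::memcpy(&back, &bits, sizeof(back));
    assert(back == 256.0f);
    return 0;
}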
namespace utils {

template <typename T>
__host__ __device__ inline int get_exponent_value(T x);

template <>
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
    return x.data;
}

} // namespace utils
} // namespace ck
include/ck/utility/enable_if.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck
{

#ifndef CK_CODE_GEN_RTC
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;

template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
#else
template <bool B, class T = void>
struct enable_if
{
};

template <class T>
struct enable_if<true, T>
{
    using type = T;
};

template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
#endif

} // namespace ck
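These aliases let SFINAE keep working under hipRTC, where the standard <type_traits> header may be unavailable; the data_type.hpp changes above route every vector_type constraint through them. Usage is identical to the std:: form. A minimal sketch (the helper below is hypothetical, not part of CK):

#include "ck/utility/enable_if.hpp"

// hypothetical helper: participates in overload resolution only for 4-byte types
template <typename T, typename = ck::enable_if_t<sizeof(T) == 4>>
constexpr const char* four_byte_only(T)
{
    return "4-byte type";
}

static_assert(four_byte_only(1.0f) != nullptr, "float is 4 bytes");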
include/ck/utility/env.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_CODE_GEN_RTC
#pragma once

#include <cstdlib>
...
...
@@ -183,3 +184,4 @@ void UpdateEnvVar(EnvVar, const std::string_view& val)
}
} // namespace ck
#endif
include/ck/utility/functional.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
...
@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
{
    if constexpr(predicate)
    {
-       return std::forward<X>(x);
+       return ck::forward<X>(x);
    }
    else
    {
-       return std::forward<Y>(y);
+       return ck::forward<Y>(y);
    }
}
...
...