gaoqiong / composable_kernel_ROCM / Commits / 6dcc40d4

Unverified commit 6dcc40d4, authored Feb 04, 2025 by Max Podkorytov, committed by GitHub on Feb 04, 2025.

    Merge branch 'develop' into ck-flex

Parents: 2c8e04aa, 800cf897
Changes: 171

Showing 20 changed files with 2726 additions and 131 deletions (+2726, -131):
    include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp   +2    -3
    include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp                                 +1    -2
    include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp                        +1    -2
    include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp                 +2    -3
    include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp       +1    -2
    include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                                +26   -67
    include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp                     +1    -0
    include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp  +12  -5
    include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp  +1    -0
    include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp             +1    -0
    include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                                    +312  -11
    include/ck/utility/amd_buffer_addressing.hpp                                                            +1    -1
    include/ck/utility/amd_ck_fp8.hpp                                                                       +14   -28
    include/ck/utility/amd_xdlops.hpp                                                                       +264  -1
    include/ck/utility/data_type.hpp                                                                        +619  -6
    include/ck/utility/e8m0.hpp                                                                             +80   -0
    include/ck/utility/mxf4_utils.hpp                                                                       +109  -0
    include/ck/utility/mxf6_utils.hpp                                                                       +325  -0
    include/ck/utility/mxf8_utils.hpp                                                                       +570  -0
    include/ck/utility/mxfp_utils.hpp                                                                       +384  -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp

@@ -43,8 +43,7 @@ __global__ void
             const B1ElementwiseOperation b1_element_op,
             const CElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();

@@ -109,7 +108,7 @@ __global__ void
     ignore = acc_element_op;
     ignore = b1_element_op;
     ignore = c_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 // Computes C = A * B0 * B1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp

@@ -38,8 +38,7 @@ __global__ void
             const BElementwiseOperation b_element_op,
             const CDEElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp

@@ -50,8 +50,7 @@ __global__ void
             const BElementwiseOperation b_element_op,
             const CDEElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp

@@ -40,8 +40,7 @@ __global__ void
             const BElementwiseOperation b_element_op,
             const CElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];

@@ -80,7 +79,7 @@ __global__ void
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <typename ALayout,
include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp

@@ -56,8 +56,7 @@ __global__ void
             const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
             const Block2ETileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t num_blocks_per_batch =
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -16,7 +16,8 @@ namespace ck {
 // [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
 // (https://arxiv.org/abs/2211.10017) and implementation:
 // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
-__host__ __device__ inline half4_t pki4_to_half4(int q)
+// Convert lower part of packed int4 -> int4 to half
+__device__ inline half4_t i4_to_half4(int q)
 {
     const int LO = 0x000f000f;
     const int HI = 0x00f000f0;

@@ -44,7 +45,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
     return res.template AsType<half4_t>()[Number<0>{}];
 }

-__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
+__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
 {
     const int LO = 0x000f000f;
     const int HI = 0x00f000f0;

@@ -78,34 +79,7 @@ __host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t&
     return res.template AsType<half4_t>()[Number<0>{}];
 }

-__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
-{
-#if 1
-    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
-    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
-
-    const int EX  = 0x64006400;
-    const int SUB = 0xE408E408; //-8
-
-    int lo = i4s | EX;
-
-    return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
-#else
-    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
-
-    vector_type<half_t, 2> res;
-
-    half_t x_h = (x_u8 & 0x0f) - 8;
-    half_t x_l = ((x_u8 & 0xf0) >> 4) - 8;
-
-    res.template AsType<half_t>()(Number<0>{}) = x_l;
-    res.template AsType<half_t>()(Number<1>{}) = x_h;
-
-    return res.template AsType<half2_t>()[Number<0>{}];
-#endif
-}
-
-__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
+__device__ inline bhalf4_t i4_to_bhalf4(int q)
 {
     uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);

@@ -134,21 +108,6 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
     return res.template AsType<bhalf4_t>()[Number<0>{}];
 }

-__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
-{
-    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
-
-    float x_h = ((x_u8 & 0x0f) >> 0) - 8.f;
-    float x_l = ((x_u8 & 0xf0) >> 4) - 8.f;
-
-    vector_type<bhalf_t, 2> res;
-
-    res.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(x_l);
-    res.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(x_h);
-
-    return res.template AsType<bhalf2_t>()[Number<0>{}];
-}
-
 namespace tensor_operation {
 namespace element_wise {

@@ -159,11 +118,11 @@ struct PassThroughPack8
     __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
     {
-#if 1
+#if CK_USE_PK4_LAYOUT_SHUFFLE
         vector_type<half_t, 8> result;

-        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
-        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
+        result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4(bit_cast<int>(x));
+        result.template AsType<half4_t>()(Number<1>{}) = i4_to_half4(bit_cast<int>(x) >> 8);

         y = result.template AsType<half8_t>()[Number<0>{}];
 #else

@@ -171,13 +130,13 @@ struct PassThroughPack8
         vector_type<pk_i4_t, 4> src{x};

         dst.template AsType<half2_t>()(Number<0>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
         dst.template AsType<half2_t>()(Number<1>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
         dst.template AsType<half2_t>()(Number<2>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
         dst.template AsType<half2_t>()(Number<3>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);

         y = dst.template AsType<half8_t>()[Number<0>{}];
 #endif

@@ -185,11 +144,11 @@ struct PassThroughPack8
     __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
     {
-#if 1
+#if CK_USE_PK4_LAYOUT_SHUFFLE
         vector_type<bhalf_t, 8> result;

-        result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
-        result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
+        result.template AsType<bhalf4_t>()(Number<0>{}) = i4_to_bhalf4(bit_cast<int>(x));
+        result.template AsType<bhalf4_t>()(Number<1>{}) = i4_to_bhalf4(bit_cast<int>(x) >> 16);

         y = result.template AsType<bhalf8_t>()[Number<0>{}];
 #else

@@ -197,13 +156,13 @@ struct PassThroughPack8
         vector_type<pk_i4_t, 4> src{x};

         dst.template AsType<bhalf2_t>()(Number<0>{}) =
-            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<0>{}]);
+            type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
         dst.template AsType<bhalf2_t>()(Number<1>{}) =
-            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<1>{}]);
+            type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
         dst.template AsType<bhalf2_t>()(Number<2>{}) =
-            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<2>{}]);
+            type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
         dst.template AsType<bhalf2_t>()(Number<3>{}) =
-            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<3>{}]);
+            type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);

         y = dst.template AsType<bhalf8_t>()[Number<0>{}];
 #endif

@@ -219,12 +178,12 @@ struct DequantPack8
     __host__ __device__ constexpr void
     operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
     {
-#if 1
+#if CK_USE_PK4_LAYOUT_SHUFFLE
         vector_type<half_t, 8> result;

-        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(bit_cast<int>(x), z);
-        result.template AsType<half4_t>()(Number<1>{}) =
-            pki4_to_half4_scale(bit_cast<int>(x) >> 8, z);
+        result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4_scale(bit_cast<int>(x), z);
+        result.template AsType<half4_t>()(Number<1>{}) =
+            i4_to_half4_scale(bit_cast<int>(x) >> 8, z);

         y = result.template AsType<half8_t>()[Number<0>{}];
 #else

@@ -232,13 +191,13 @@ struct DequantPack8
         vector_type<pk_i4_t, 4> src{x};

         dst.template AsType<half2_t>()(Number<0>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
         dst.template AsType<half2_t>()(Number<1>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
         dst.template AsType<half2_t>()(Number<2>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
         dst.template AsType<half2_t>()(Number<3>{}) =
-            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
+            type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);

         y = dst.template AsType<half8_t>()[Number<0>{}];
 #endif

@@ -260,7 +219,7 @@ struct PassThroughPack2
     __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
     {
-#if 1
+#if CK_USE_PK4_LAYOUT_SHUFFLE
         uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
         uint8_t x_l  = (x_u8 & 0x0f) >> 0;
         uint8_t x_h  = (x_u8 & 0xf0) >> 4;
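Note on the retained i4_to_half4 path: the FasterTransformer reference linked above decodes a 4-bit field by OR-ing it into the low mantissa bits of a half-precision value whose exponent pattern is 0x6400 (1024.0), giving 1024 + n, and then subtracting 1032 to recover n - 8 in one packed add. Below is a minimal host-side sketch of the same idea using fp32 and the analogous magic constant 0x4B000000 (8388608.0f); the function name and the choice of scalar fp32 instead of packed fp16 are illustrative, not part of this commit.

#include <cstdint>
#include <cstring>
#include <cstdio>

// Decode a 4-bit value stored as "offset by 8" into a signed value in [-8, 7]
// using only an OR plus one float subtraction (mirrors the EX/SUB constants above).
static float nibble_to_float_via_magic(uint8_t nibble) // nibble in [0, 15]
{
    // 0x4B000000 is 8388608.0f (2^23); OR-ing an integer n < 2^23 into the mantissa
    // produces the float 8388608.0f + n exactly.
    uint32_t bits = 0x4B000000u | nibble;
    float magic;
    std::memcpy(&magic, &bits, sizeof(magic));
    // Subtract 8388608 + 8 to undo the magic exponent and the +8 offset in one step.
    return magic - 8388616.0f; // == nibble - 8
}

int main()
{
    for(uint8_t n = 0; n < 16; ++n)
        std::printf("%2u -> %4.0f\n", n, nibble_to_float_via_magic(n));
    return 0;
}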
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp

@@ -607,6 +607,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
     // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
     // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
     // therefore we may just as well assign Gemm1KPack = group_size
     constexpr index_t Gemm1KPack =
         MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp

@@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
             static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize());

-        constexpr index_t Gemm1KPack = math::max(
-            math::lcm(
-                MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size,
-                B1K1),
-            MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
+        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
+        // selected_mfma.k_per_blk <= Gemm1KPack
+        //
+        // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
+        // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
+        // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
+        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
+        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
+        // therefore we may just as well assign Gemm1KPack = group_size
+        constexpr index_t Gemm1KPack =
+            MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size;

         auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<
             BlockSize,
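For intuition about the removed selection rule versus the new assignment, here is a small standalone sketch. The numbers (group_size = 4, B1K1 = 2, k_per_blk = 4) are made up for illustration and are not taken from the commit; with these values the old rule max(lcm(group_size, B1K1), k_per_blk) and the new rule Gemm1KPack = group_size coincide, which is the situation the comment above argues holds in practice.

#include <algorithm>
#include <cstdio>
#include <numeric>

int main()
{
    // Illustrative values only; the real ones come from MfmaSelector and the tile configuration.
    const int group_size = 4;
    const int B1K1       = 2;
    const int k_per_blk  = 4;

    const int old_gemm1_kpack = std::max(std::lcm(group_size, B1K1), k_per_blk);
    const int new_gemm1_kpack = group_size;

    std::printf("old rule: %d, new rule: %d\n", old_gemm1_kpack, new_gemm1_kpack);
    return 0;
}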
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp

@@ -773,6 +773,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
     // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
     // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
     // therefore we may just as well assign Gemm1KPack = group_size
     constexpr index_t Gemm1KPack =
         MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp

@@ -628,6 +628,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
     // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
     // therefore we may just as well assign Gemm1KPack = group_size
     constexpr index_t Gemm1KPack =
         MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -37,7 +37,17 @@ enum struct MfmaInstr
     mfma_f32_32x32x16f8bf8,
     mfma_f32_16x16x32f8bf8,
     mfma_f32_32x32x16bf8f8,
-    mfma_f32_16x16x32bf8f8
+    mfma_f32_16x16x32bf8f8,
+    mfma_f32_32x32x16f16,
+    mfma_f32_16x16x32f16,
+    mfma_f32_32x32x16bf16,
+    mfma_f32_16x16x32bf16,
+    mfma_i32_32x32x32i8,
+    mfma_i32_16x16x64i8,
+    mfma_f32_32x32x64f8f6f4,
+    mfma_f32_16x16x128f8f6f4,
+    mfma_scale_f32_32x32x64f8f6f4,
+    mfma_scale_f32_16x16x128f8f6f4
 };

 template <MfmaInstr instr>

@@ -198,6 +208,50 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8f16>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x16f16>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x32f16>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_16x16x16f16>
 {

@@ -264,6 +318,28 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf16>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
 {

@@ -286,6 +362,28 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf16>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_16x16x16bf16_1k>
 {

@@ -440,6 +538,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x32i8>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_i32_32x32x32i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 16;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_32x32x32i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_i32_16x16x64i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 16;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_16x16x64i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
 {

@@ -638,16 +780,115 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
     }
 };

+// TODO: fix mfma...f8f6f4 instructions
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
+{
+    // clang-format off
+    static constexpr index_t group_size          = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 16;   // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 32;   // n_per_blk
+    static constexpr index_t wave_size           = 64;   // fixed
+    static constexpr index_t num_input_blks      = 2;    // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;    // (is_k_reduction == true) ???
+    static constexpr index_t m_per_blk           = 32;   // from the instruction
+    static constexpr index_t n_per_blk           = 32;   // from the instruction
+    static constexpr index_t k_per_blk           = 32;   // (is_k_reduction == true) ? 64 / num_input_blks
+    static constexpr bool is_k_reduction         = true; // ???
+    // clang-format on
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
+{
+    // clang-format off
+    static constexpr index_t group_size          = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 1;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 4;    // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 16;   // == n_per_blk
+    static constexpr index_t wave_size           = 64;   // fixed
+    static constexpr index_t num_input_blks      = 4;    // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;    // (is_k_reduction == true) ???
+    static constexpr index_t m_per_blk           = 16;   // from the instruction
+    static constexpr index_t n_per_blk           = 16;   // from the instruction
+    static constexpr index_t k_per_blk           = 32;   // (is_k_reduction == true) ? 128 / num_input_blks
+    static constexpr bool is_k_reduction         = true; // ???
+    // clang-format on
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
+{
+    // clang-format off
+    static constexpr index_t group_size          = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 16;   // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 32;   // n_per_blk
+    static constexpr index_t wave_size           = 64;   // fixed
+    static constexpr index_t num_input_blks      = 2;    // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;    // (is_k_reduction == true) ???
+    static constexpr index_t m_per_blk           = 32;   // from the instruction
+    static constexpr index_t n_per_blk           = 32;   // from the instruction
+    static constexpr index_t k_per_blk           = 32;   // (is_k_reduction == true) ? 64 / num_input_blks
+    static constexpr bool is_k_reduction         = true; // ???
+    // clang-format on
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
+{
+    // clang-format off
+    static constexpr index_t group_size          = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 1;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 4;    // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 16;   // == n_per_blk
+    static constexpr index_t wave_size           = 64;   // fixed
+    static constexpr index_t num_input_blks      = 4;    // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;    // (is_k_reduction == true) ???
+    static constexpr index_t m_per_blk           = 16;   // from the instruction
+    static constexpr index_t n_per_blk           = 16;   // from the instruction
+    static constexpr index_t k_per_blk           = 32;   // (is_k_reduction == true) ? 128 / num_input_blks
+    static constexpr bool is_k_reduction         = true; // ???
+    // clang-format on
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <typename base_type,
           index_t MPerXdlops,
           index_t NPerXdlops,
-          typename additional_type = base_type>
+          typename additional_type = base_type,
+          bool is_single_rate_mfma = false>
 struct MfmaSelector
 {
     template <typename base_type_,
               index_t MPerXdlops_,
               index_t NPerXdlops_,
-              typename additional_type_ = base_type_>
+              typename additional_type_ = base_type_,
+              bool is_single_rate_mfma_ = false>
     static constexpr auto GetMfma();

     template <>

@@ -711,13 +952,32 @@ struct MfmaSelector
     }

     template <>
-    constexpr auto GetMfma<half_t, 32, 32>()
+    constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
+    {
+#if defined(__gfx950__)
+        return MfmaInstr::mfma_f32_32x32x16f16;
+#else
+        return MfmaInstr::mfma_f32_32x32x8f16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
     {
         return MfmaInstr::mfma_f32_32x32x8f16;
     }

     template <>
-    constexpr auto GetMfma<half_t, 16, 16>()
+    constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
+    {
+#if defined(__gfx950__)
+        return MfmaInstr::mfma_f32_16x16x32f16;
+#else
+        return MfmaInstr::mfma_f32_16x16x16f16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
     {
         return MfmaInstr::mfma_f32_16x16x16f16;
     }

@@ -741,7 +1001,19 @@ struct MfmaSelector
     }

     template <>
-    constexpr auto GetMfma<bhalf_t, 32, 32>()
+    constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
+    {
+#if defined(__gfx950__)
+        return MfmaInstr::mfma_f32_32x32x16bf16;
+#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
+        return MfmaInstr::mfma_f32_32x32x8bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_32x32x4bf16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
     {
 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
         return MfmaInstr::mfma_f32_32x32x8bf16_1k;

@@ -751,7 +1023,19 @@ struct MfmaSelector
     }

     template <>
-    constexpr auto GetMfma<bhalf_t, 16, 16>()
+    constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
+    {
+#if defined(__gfx950__)
+        return MfmaInstr::mfma_f32_16x16x32bf16;
+#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
+        return MfmaInstr::mfma_f32_16x16x16bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_16x16x8bf16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
     {
 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
         return MfmaInstr::mfma_f32_16x16x16bf16_1k;

@@ -760,7 +1044,18 @@ struct MfmaSelector
 #endif
     }

-#if defined(CK_USE_AMD_MFMA_GFX940)
+#if defined(__gfx950__)
+    template <>
+    constexpr auto GetMfma<int8_t, 32, 32>()
+    {
+        return MfmaInstr::mfma_i32_32x32x32i8;
+    }
+
+    template <>
+    constexpr auto GetMfma<int8_t, 16, 16>()
+    {
+        return MfmaInstr::mfma_i32_16x16x64i8;
+    }
+#elif defined(__gfx942__)
     template <>
     constexpr auto GetMfma<int8_t, 32, 32>()
     {

@@ -832,8 +1127,8 @@ struct MfmaSelector
         return MfmaInstr::mfma_f32_16x16x32bf8f8;
     }

     static constexpr auto selected_mfma =
-        mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
+        mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};

     __host__ __device__ constexpr MfmaSelector()
     {

@@ -1135,7 +1430,13 @@ struct XdlopsGemm
         return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
     }

-    static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
+    // Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942-
+    static constexpr auto mfma =
+        MfmaSelector<base_type,
+                     MPerXdlops,
+                     NPerXdlops,
+                     additional_type,
+                     ((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) &&
+                      KPack <= 4)
+                         ? true
+                         : false>{};

     static constexpr auto mfma_instr = mfma.selected_mfma;
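The new is_single_rate_mfma flag only changes which enum value GetMfma returns, so the fallback rule above can be read in isolation. The sketch below imitates that selection logic; the enum and function here are stand-ins written for this note rather than the CK types themselves, and the gfx950 check is modeled as a plain boolean instead of a preprocessor test.

#include <cstdio>

enum class Instr { f32_32x32x8f16, f32_32x32x16f16 };

// Mimics GetMfma<half_t, 32, 32, half_t, is_single_rate>: on gfx950 the double-rate
// 32x32x16 instruction is preferred unless the caller forces the single-rate path
// (which XdlopsGemm does when KPack <= 4).
constexpr Instr select_fp16_32x32(bool is_gfx950, bool is_single_rate)
{
    if(is_gfx950 && !is_single_rate)
        return Instr::f32_32x32x16f16;
    return Instr::f32_32x32x8f16;
}

int main()
{
    constexpr int KPack        = 4; // illustrative tile parameter
    constexpr bool single_rate = (KPack <= 4);
    constexpr Instr picked     = select_fp16_32x32(/*is_gfx950=*/true, single_rate);

    std::printf("picked %s\n",
                picked == Instr::f32_32x32x16f16 ? "mfma_f32_32x32x16f16" : "mfma_f32_32x32x8f16");
    return 0;
}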
include/ck/utility/amd_buffer_addressing.hpp

@@ -581,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
                                       tmp.template AsType<half2_t>()[i]);
         });
     }
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
     else if constexpr(is_same<T, bhalf_t>::value)
     {
         vector_type<bhalf_t, N> tmp{src_thread_data};
include/ck/utility/amd_ck_fp8.hpp

@@ -20,39 +20,25 @@
 #define CK_USE_OCP_FP8 0
 #endif

-namespace {
-// https://en.cppreference.com/w/cpp/types/conditional
-template <bool B, class T, class F>
-struct conditional
-{
-    using type = T;
-};
-
-template <class T, class F>
-struct conditional<false, T, F>
-{
-    using type = F;
-};
-} // namespace
-
-namespace ck {
-
-using f8_fnuz_t  = _BitInt(8);
-using bf8_fnuz_t = unsigned _BitInt(8);
-
 #if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
-    defined(__gfx1201__)) &&                                                                      \
+    defined(__gfx1201__) || defined(__gfx950__)) &&                                              \
     __HIP_DEVICE_COMPILE__
 #define CK_FP8_CVT_FAST_PATH 1
 #else
 #define CK_FP8_CVT_FAST_PATH 0
 #endif

-#if(defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__
+#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
 #define CK_OCP_FP8_CVT_FAST_PATH 1
 #else
 #define CK_OCP_FP8_CVT_FAST_PATH 0
 #endif

+namespace ck {
+
+using f8_fnuz_t  = _BitInt(8);
+using bf8_fnuz_t = unsigned _BitInt(8);
+
 typedef unsigned char fp8_storage_t;

 /**

@@ -207,10 +193,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
         }
     }

-    typename conditional<
+    typename std::conditional<
         sizeof(T) == 2,
         unsigned short int,
-        typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type retval;
+        typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type retval;

     if constexpr(we == 5 && is_half && !is_fnuz)
     {

@@ -303,7 +290,6 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
         return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
     }
 }
 #endif
 } // namespace fp8_impl

@@ -378,7 +364,7 @@ struct bf8_ocp_t
     __host__ explicit operator float() const
 #endif
     {
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
         return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
 #else
         return fp8_impl::cast_from_f8<float, wm, we, false>(

@@ -392,7 +378,7 @@ struct bf8_ocp_t
     __host__ explicit operator _Float16() const
 #endif
     {
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
         return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
 #else
         return fp8_impl::cast_from_f8<_Float16, wm, we, false>(

@@ -553,10 +539,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
     constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);

-    using T_bitwise = typename conditional<
+    using T_bitwise = typename std::conditional<
         sizeof(T) == 2,
         unsigned short int,
-        typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
+        typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;

     T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);

     unsigned long long x{x_bitwise};
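The conditional-related change above simply swaps the file-local conditional helper for std::conditional from the standard <type_traits> header. The snippet below is a minimal demonstration of the same sizeof-driven type selection; the alias name is a hypothetical stand-in introduced only for this note.

#include <cstdint>
#include <type_traits>

// Pick an unsigned integer type with the same width as T (2, 4, or 8 bytes),
// the same nesting pattern that cast_from_f8/cast_to_f8 use for bitwise access.
template <typename T>
using bitwise_equivalent_t = typename std::conditional<
    sizeof(T) == 2,
    unsigned short int,
    typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;

static_assert(std::is_same<bitwise_equivalent_t<std::uint16_t>, unsigned short int>::value, "");
static_assert(std::is_same<bitwise_equivalent_t<float>, unsigned int>::value, "");
static_assert(std::is_same<bitwise_equivalent_t<double>, unsigned long long>::value, "");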
include/ck/utility/amd_xdlops.hpp

@@ -5,7 +5,7 @@
 namespace ck {

 // Define the common macro for MI300 models
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
 #define __gfx94__
 #endif

@@ -134,6 +134,46 @@ struct intrin_mfma_f32_32x32x4f16<32, 64>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x16f16;
+
+template <>
+struct intrin_mfma_f32_32x32x16f16<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x32f16;
+
+template <>
+struct intrin_mfma_f32_16x16x32f16<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x8f16;

@@ -204,6 +244,46 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
 };

 // bfp16

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x16bf16;
+
+template <>
+struct intrin_mfma_f32_32x32x16bf16<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x32bf16;
+
+template <>
+struct intrin_mfma_f32_16x16x32bf16<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x8bf16_1k;

@@ -298,6 +378,46 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_32x32x32i8;
+
+template <>
+struct intrin_mfma_i32_32x32x32i8<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
+            reg_a, reg_b, reg_c.template AsType<int32x16_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_16x16x64i8;
+
+template <>
+struct intrin_mfma_i32_16x16x64i8<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
+            reg_a, reg_b, reg_c.template AsType<int32x4_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_i32_32x32x16i8;

@@ -356,6 +476,149 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x64f8f6f4;
+
+/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6,
+/// and f4 data types.
+///
+/// @note Calls scaled version of the instruction as the original instruction is not supported in
+/// the backend. That is the intended use. There is a backend optimization to select the unscaled
+/// operation if the scale is 0.
+template <>
+struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float16_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float16_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0,
+                0,
+                0,
+                0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_scale_f32_32x32x64f8f6f4;
+
+template <>
+struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x32_t& reg_a,
+                               const int32_t scale_a,
+                               const f8x32_t& reg_b,
+                               const int32_t scale_b,
+                               FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
+        reg_c.template AsType<float16_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float16_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0, // { OPSEL_HI[0], OPSEL[0] }?
+                scale_a,
+                0, // { OPSEL_HI[1], OPSEL[1] }?
+                scale_b);
+#else
+        ignore = reg_a;
+        ignore = scale_a;
+        ignore = reg_b;
+        ignore = scale_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_scale_f32_16x16x128f8f6f4;
+
+template <>
+struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x32_t& reg_a,
+                               const int32_t scale_a,
+                               const f8x32_t& reg_b,
+                               const int32_t scale_b,
+                               FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float4_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0, // { OPSEL_HI[0], OPSEL[0] }?
+                scale_a,
+                0, // { OPSEL_HI[1], OPSEL[1] }?
+                scale_b);
+#else
+        ignore = reg_a;
+        ignore = scale_a;
+        ignore = reg_b;
+        ignore = scale_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x128f8f6f4;
+
+/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4
+/// data types.
+///
+/// @note Calls scaled version of the instruction as the original instruction is not supported in
+/// the backend. That is the intended use. There is a backend optimization to select the unscaled
+/// operation if the scale is 0.
+template <>
+struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float4_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0,
+                0,
+                0,
+                0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x16f8f8;
include/ck/utility/data_type.hpp
(diff collapsed in the original view; contents not shown)
include/ck/utility/e8m0.hpp
0 → 100644
View file @
6dcc40d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/type.hpp"

namespace ck {

/**
 * @brief Unsigned representation of a conventional biased Float32 exponent.
 *
 * bias = 127;
 *
 * E8M0_1   = 0b01111111; => 2^(127-127) = 1
 * E8M0_2   = 0b10000000; => 2^(128-127) = 2^1 = 2
 * E8M0_3   = 0b10000010; => 2^(130-127) = 2^3 = 8
 * E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
 * E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
 * E8M0_MIN = 0b00000000; => 2^-127
 * E8M0_MAX = 0b11111110; => 2^127
 * E8M0_NAN = 0b11111111; => NaN
 */
struct e8m0_bexp_t
{
    using type = uint8_t;

    type data;

    constexpr static type bias     = 127;
    constexpr static type nan_mask = 0xFF;

    __host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {}

    __host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {}

    __host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast<type>(init & nan_mask)}
    {
    }

    __host__ __device__ explicit constexpr e8m0_bexp_t(float scale)
        : data{static_cast<type>((bit_cast<uint32_t>(scale) & (nan_mask << 23)) >> 23)}
    {
    }

    __host__ __device__ explicit constexpr operator float() const
    {
        if(data == nan_mask || data == 0)
        {
            uint32_t bits = data << 1;
            bits |= 1;
            bits <<= 22;
            return bit_cast<float>(bits);
        }
        else
        {
            uint32_t bits = data << 23;
            return bit_cast<float>(bits);
        }
    }

    __host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const
    {
        // strict IEEE compliance for NaN
        return data == other.data && data != nan_mask;
    }

    __host__ __device__ constexpr bool is_nan() const { return data == nan_mask; }
};

namespace utils {

template <typename T>
__host__ __device__ inline int get_exponent_value(T x);

template <>
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
    return x.data;
}

} // namespace utils
} // namespace ck
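As a quick sanity check on the encoding, here is a small host-side sketch (illustration only; the function is hypothetical, and the values follow directly from the constructor and conversion operator above):

// Illustrative only: round-tripping a power-of-two scale through e8m0_bexp_t.
inline float e8m0_roundtrip_sketch()
{
    ck::e8m0_bexp_t s{8.0f};      // 8.0f = 2^3, so the biased FP32 exponent is 127 + 3 = 130
    // s.data == 130 (0b10000010)
    // The mantissa is discarded, so e8m0_bexp_t{6.0f}.data == 129, which decodes to 4.0f.
    return static_cast<float>(s); // decodes as 2^(130 - 127) = 8.0f
}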
include/ck/utility/mxf4_utils.hpp (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"

namespace ck::utils {

template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
                                             f4_t const dataBytes [[maybe_unused]])
{
    // no need to check for data as it does not have NaN representation
    return scale == NumericLimits<e8m0_bexp_t>::QuietNaN();
}

// no infinity representation in ocp_e2m1_mxfp4; will always return false
template <>
__host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                             f4_t const data [[maybe_unused]])
{
    // no inf representation for ocp_e2m1_mxfp4
    return false;
}

template <>
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
    if(is_nan<f4_t>(scale, data))
        return false;

    // no need to check for scale as it does not have a 0 representation
    f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;

    return result == 0b0;
}

template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
    if(is_nan<f4_t>(scale, data))
        return std::numeric_limits<float>::quiet_NaN();

    if(is_zero<f4_t>(scale, data))
        return 0.0f;

    f4_t prepared_data = data & 0b00001111;

    int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);

    return convert_to_float<f4_t>(prepared_data, scale_exp);
}

template <>
__host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
    {
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;
    }

    if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;

    f4_t res = convert_to_type<f4_t>(value);

    if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<f4_t>::DataMinSubnorm())
        return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
                         : NumericUtils<f4_t>::positive_zero_mask;

    return res;
}

template <>
__host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32_t seed)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;

    if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;

    f4_t res = convert_to_type_sr<f4_t>(value, seed);

    if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<f4_t>::DataMinSubnorm())
        return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
                         : NumericUtils<f4_t>::positive_zero_mask;

    return res;
}

} // namespace ck::utils
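A short sketch of how these specializations behave together (illustration only; the wrapper function is hypothetical, and the clamp target is the e2m1 max-magnitude normal, which the OCP MX specification puts at 6.0):

// Illustrative only: saturating float -> f4_t conversion, then decoding with a unit scale.
inline float f4_saturate_sketch()
{
    ck::f4_t q = ck::utils::sat_convert_to_type<ck::f4_t>(1.0e6f); // out of range: clamps
    // e8m0_bexp_t{127} encodes 2^0, so the decode applies no extra scaling.
    return ck::utils::to_float<ck::f4_t>(ck::e8m0_bexp_t{127}, q); // largest e2m1 value, not 1.0e6f
}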
include/ck/utility/mxf6_utils.hpp (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"

namespace ck::utils {

/**
 * @brief Checks if an f6_t value is NaN based on the provided scale.
 *
 * For f6_t data, NaN cannot be represented directly. Instead, this function
 * determines NaN by checking if the scale is set to a quiet NaN.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param dataBytes The f6_t value to check (unused in this implementation).
 * @return true if the scale indicates a NaN value, false otherwise.
 */
template <>
__host__ __device__ inline bool is_nan<f6_t>(e8m0_bexp_t const scale,
                                             f6_t const dataBytes [[maybe_unused]])
{
    // no need to check for data as it does not have NaN representation
    return scale.is_nan();
}

/**
 * @brief Checks if a bf6_t value is NaN based on the provided scale.
 *
 * For bf6_t data, NaN cannot be represented directly. Instead, this function
 * determines NaN by checking if the scale is set to a quiet NaN.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param dataBytes The bf6_t value to check (unused in this implementation).
 * @return true if the scale indicates a NaN value, false otherwise.
 */
template <>
__host__ __device__ inline bool is_nan<bf6_t>(e8m0_bexp_t const scale,
                                              bf6_t const dataBytes [[maybe_unused]])
{
    // no need to check for data as it does not have NaN representation
    return scale.is_nan();
}

/**
 * @brief Checks if an f6_t value is infinite.
 *
 * Because f6_t does not support infinite values, this function always returns false.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param data The f6_t value to check.
 * @return Always false, as infinity is not represented in f6_t.
 */
template <>
__host__ __device__ inline bool is_inf<f6_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                             f6_t const data [[maybe_unused]])
{
    // no inf representation for fp6
    return false;
}

/**
 * @brief Checks if a bf6_t value is infinite.
 *
 * Because bf6_t does not support infinite values, this function always returns false.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param data The bf6_t value to check.
 * @return Always false, as infinity is not represented in bf6_t.
 */
template <>
__host__ __device__ inline bool is_inf<bf6_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                              bf6_t const data [[maybe_unused]])
{
    // no inf representation for bf6
    return false;
}

/**
 * @brief Checks whether an f6_t value is zero.
 *
 * If the specified f6_t is NaN, this function returns false.
 * Otherwise, it masks out the sign bits and checks if the remaining bits
 * are zero.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param data The f6_t value to check.
 * @return true if the value is zero; otherwise false.
 */
template <>
__host__ __device__ inline bool is_zero<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
    if(is_nan<f6_t>(scale, data))
        return false;

    // no need to check for scale as it does not have a 0 representation
    f6_t result = (data & 0b00111111) & NumericUtils<f6_t>::set_sign_mask;

    return result == 0b0;
}

/**
 * @brief Checks whether a bf6_t value is zero.
 *
 * If the specified bf6_t is NaN, this function returns false.
 * Otherwise, it masks out the sign bits and checks if the remaining bits
 * are zero.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param data The bf6_t value to check.
 * @return true if the value is zero; otherwise false.
 */
template <>
__host__ __device__ inline bool is_zero<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
    if(is_nan<bf6_t>(scale, data))
        return false;

    // no need to check for scale as it does not have a 0 representation
    bf6_t result = (data & 0b00111111) & NumericUtils<bf6_t>::set_sign_mask;

    return result == 0b0;
}

/**
 * @brief Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.
 *
 * Checks if the f6_t value is NaN or zero before performing the conversion.
 * Applies the exponent from the scale to compute the final float result.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param data The f6_t value to convert.
 * @return The converted float value.
 */
template <>
__host__ __device__ inline float to_float<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
    if(is_nan<f6_t>(scale, data))
        return std::numeric_limits<float>::quiet_NaN();

    if(is_zero<f6_t>(scale, data))
        return 0.0f;

    f6_t prepared_data = data & 0b00111111;

    int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);

    return convert_to_float<f6_t>(prepared_data, scale_exp);
}

/**
 * @brief Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor.
 *
 * Checks if the bf6_t value is NaN or zero before performing the conversion.
 * Applies the exponent from the scale to compute the final float result.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param data The bf6_t value to convert.
 * @return The converted float value.
 */
template <>
__host__ __device__ inline float to_float<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
    if(is_nan<bf6_t>(scale, data))
        return std::numeric_limits<float>::quiet_NaN();

    if(is_zero<bf6_t>(scale, data))
        return 0.0f;

    bf6_t prepared_data = data & 0b00111111;

    int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);

    return convert_to_float<bf6_t>(prepared_data, scale_exp);
}

/**
 * @brief Converts a float to f6_t with saturation.
 *
 * If the input is NaN or exceeds the representable range for f6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @return The saturated f6_t value.
 */
template <>
__host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
    {
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;
    }

    if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;

    f6_t res = convert_to_type<f6_t>(value);

    if(std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<f6_t>::DataMinSubnorm())
        return sign ? NumericUtils<f6_t>::negative_zero_mask
                    : NumericUtils<f6_t>::positive_zero_mask;

    return res;
}

/**
 * @brief Converts a float to bf6_t with saturation.
 *
 * If the input is NaN or exceeds the representable range for bf6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @return The saturated bf6_t value.
 */
template <>
__host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
    {
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;
    }

    if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;

    bf6_t res = convert_to_type<bf6_t>(value);

    if(std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<bf6_t>::DataMinSubnorm())
        return sign ? NumericUtils<bf6_t>::negative_zero_mask
                    : NumericUtils<bf6_t>::positive_zero_mask;

    return res;
}

/**
 * @brief Converts a float to f6_t with saturation and stochastic rounding.
 *
 * If the input is NaN or exceeds the representable range for f6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @param seed Seed used for stochastic rounding.
 * @return The saturated f6_t value.
 */
template <>
__host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32_t seed)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;

    if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;

    f6_t res = convert_to_type_sr<f6_t>(value, seed);

    if(std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<f6_t>::DataMinSubnorm())
        return sign ? NumericUtils<f6_t>::negative_zero_mask
                    : NumericUtils<f6_t>::positive_zero_mask;

    return res;
}

/**
 * @brief Converts a float to bf6_t with saturation and stochastic rounding.
 *
 * If the input is NaN or exceeds the representable range for bf6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @param seed Seed used for stochastic rounding.
 * @return The saturated bf6_t value.
 */
template <>
__host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint32_t seed)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;

    if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;

    bf6_t res = convert_to_type_sr<bf6_t>(value, seed);

    if(std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<bf6_t>::DataMinSubnorm())
        return sign ? NumericUtils<bf6_t>::negative_zero_mask
                    : NumericUtils<bf6_t>::positive_zero_mask;

    return res;
}

} // namespace ck::utils
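The f6/bf6 helpers mirror the f4 ones; a brief sketch of the NaN convention (illustration only, hypothetical function): the 6-bit payload cannot encode NaN, so NaN is produced and detected through the block scale instead.

// Illustrative only: NaN handling for the 6-bit MX formats (requires <limits>).
inline bool f6_nan_sketch()
{
    // A NaN input saturates to the max-magnitude normal payload...
    ck::f6_t q = ck::utils::sat_convert_to_type<ck::f6_t>(std::numeric_limits<float>::quiet_NaN());
    // ...and NaN is then signalled by setting the e8m0 scale to 0xFF.
    return ck::utils::is_nan<ck::f6_t>(ck::e8m0_bexp_t{0xFF}, q); // true
}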
include/ck/utility/mxf8_utils.hpp (new file)
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"

#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
#define CK_MX_FP8_CVT_FAST_PATH 1
#else
#define CK_MX_FP8_CVT_FAST_PATH 0
#endif

namespace ck {
namespace fp8_impl {

#if CK_MX_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
{
    union
    {
        unsigned int i32val;
        unsigned char i8val[4];
    } val;
    val.i8val[0] = v;

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
    }
}

template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
{
    const auto i16val = bit_cast<uint16_t>(v);

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
    }
}

template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8_storage_t
cast_to_f8_from_f32_scaled(float v, unsigned int rng = 0, float scale = 1.0f)
{
    fp8_storage_t i8data;
    union
    {
        float fval;
        unsigned int i32val;
    } val;

    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        fp8_storage_t v4i8[4];
    } ret{};

    // unsigned int ival = 0;
    val.fval = v;

    if constexpr(stochastic_rounding)
    {
        ret.ival =
            (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
                ? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
                : __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);
        i8data = ret.v4i8[0];
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If fval / scale > max fp8, returns Nan
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
                                                                 val.fval,
                                                                 val.fval,
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If fval / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
                                                                 val.fval,
                                                                 val.fval,
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }
        i8data = ret.v4i8[0];
    }

    return i8data;
}

template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t
cast_to_f8_from_f32_scaled(float2_t v, unsigned int rng = 0, float scale = 1.0f)
{
    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
    } ret{};

    if constexpr(stochastic_rounding)
    {
        fp8x2_storage_t f8x2;
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
            f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
            f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
        }
        else
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
            f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
            f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
        }
        return f8x2;
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If fval / scale > max fp8, returns Nan
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(
                /*old_vdst*/ ret.v2i16, v[0], v[1], scale, /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If fval / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(
                /*old_vdst*/ ret.v2i16, v[0], v[1], scale, /*dst_lo_hi_sel*/ false);
        }
        return ret.v2f8x2(Number<0>{});
    }
}
#endif // CK_MX_FP8_CVT_FAST_PATH

#if CK_MX_FP8_CVT_FAST_PATH
/**
 * \brief convert float to @p fp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is available
 *
 * \tparam interp interpretation of fp8
 * \param f float number
 * \param scale scaling factor
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
    __is_interpret_supported(interp);
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
    }
    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}

/**
 * \brief convert 2xfloat to @p 2xfp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is available
 *
 * \tparam interp interpretation of fp8
 * \param f 2xfloat
 * \param scale scaling factor
 * \return 2xfp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                          float scale)
{
    __is_interpret_supported(interp);
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
    }
    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
#else
/**
 * \brief convert float to @p fp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is not available
 *
 * \tparam interp interpretation of fp8
 * \param f float number
 * \param scale scaling factor
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
    }

    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
    }
    else
    {
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
}

/**
 * \brief convert two float to @p 2xfp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is not available
 *
 * \tparam interp interpretation of fp8
 * \param f 2xfloat
 * \param scale scaling factor
 * \return 2xfp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                          float scale)
{
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
    }

    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
                cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
                cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
    }
    else
    {
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
}
#endif // CK_MX_FP8_CVT_FAST_PATH

} // namespace fp8_impl

// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale);

// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale);

// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_rne<f8_ocp_t, float>(float x, float scale)
{
    return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_rne<bf8_ocp_t, float>(float x, float scale)
{
    return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32x2 to fp8x2 with rounding to nearest even
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x,
                                                                             float scale)
{
    return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32x2 to bf8x2 with rounding to nearest even
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                               float scale)
{
    return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32x16 to fp8x16 with rounding to nearest even
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_rne<f8x16_ocp_t, float16_t>(float16_t x,
                                                                                float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};

    union
    {
        f8x16_ocp_t fp8_1x16;
        f8x2_ocp_t fp8_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}(
        [&](auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.fp8_1x16;
}

// convert fp32x16 to bf8x16 with rounding to nearest even
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_rne<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                  float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};

    union
    {
        bf8x16_ocp_t bf8_1x16;
        bf8x2_ocp_t bf8_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}(
        [&](auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.bf8_1x16;
}

// convert fp32x32 to fp8x32 with rounding to nearest even
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_rne<f8x32_ocp_t, float32_t>(float32_t x,
                                                                                float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};

    union
    {
        f8x32_ocp_t fp8_1x32;
        f8x16_ocp_t fp8_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}(
        [&](auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.fp8_1x32;
}

// convert fp32x32 to bf8x32 with rounding to nearest even
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_rne<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                  float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};

    union
    {
        bf8x32_ocp_t bf8_1x32;
        bf8x16_ocp_t bf8_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}(
        [&](auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.bf8_1x32;
}

// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_sr<f8_ocp_t, float>(float x, float scale)
{
    return f8_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_sr<bf8_ocp_t, float>(float x, float scale)
{
    return bf8_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32x2 to fp8x2 with stochastic rounding
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x,
                                                                            float scale)
{
    return f8x2_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32x2 to bf8x2 with stochastic rounding
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                              float scale)
{
    return bf8x2_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32x16 to fp8x16 with stochastic rounding
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_sr<f8x16_ocp_t, float16_t>(float16_t x,
                                                                               float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};

    union
    {
        f8x16_ocp_t fp8_1x16;
        f8x2_ocp_t fp8_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}(
        [&](auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.fp8_1x16;
}

// convert fp32x16 to bf8x16 with stochastic rounding
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_sr<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                 float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};

    union
    {
        bf8x16_ocp_t bf8_1x16;
        bf8x2_ocp_t bf8_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}(
        [&](auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.bf8_1x16;
}

// convert fp32x32 to fp8x32 with stochastic rounding
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_sr<f8x32_ocp_t, float32_t>(float32_t x,
                                                                               float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};

    union
    {
        f8x32_ocp_t fp8_1x32;
        f8x16_ocp_t fp8_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}(
        [&](auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.fp8_1x32;
}

// convert fp32x32 to bf8x32 with stochastic rounding
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_sr<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                 float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};

    union
    {
        bf8x32_ocp_t bf8_1x32;
        bf8x16_ocp_t bf8_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}(
        [&](auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.bf8_1x32;
}

} // namespace ck
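In terms of usage, the public entry points of this header are the mxf8_convert_rne / mxf8_convert_sr specializations; a small sketch follows (illustration only; the two wrapper functions are hypothetical):

// Illustrative only: quantizing floats to OCP fp8/bf8 against a shared block scale.
__host__ __device__ inline ck::f8x2_ocp_t quantize_pair_rne(ck::float2_t v, float scale)
{
    // Divides by `scale` (or feeds it to the gfx950 scalef32 builtins on the fast path),
    // then rounds to nearest even.
    return ck::mxf8_convert_rne<ck::f8x2_ocp_t>(v, scale);
}

__host__ __device__ inline ck::bf8_ocp_t quantize_one_sr(float v, float scale)
{
    // Same scaling, but with stochastic rounding.
    return ck::mxf8_convert_sr<ck::bf8_ocp_t>(v, scale);
}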
include/ck/utility/mxfp_utils.hpp (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck::utils {

union cvt
{
    float value_float;
    uint32_t value_bitwise;
};

template <typename DTYPE>
inline bool getDataHasInf()
{
    return DTYPE::dataInfo.hasInf;
}

template <typename T>
__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);

template <typename T>
__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);

template <typename T>
__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);

template <typename T>
__host__ __device__ inline int get_exponent_value(T x)
{
    x >>= NumericUtils<T>::mant;
    x &= ((1 << NumericUtils<T>::exp) - 1);
    return static_cast<int>(x);
}

template <typename T>
__host__ __device__ inline bool is_subnormal(T x)
{
    return get_exponent_value<T>(x) == 0;
}

template <typename T>
__host__ __device__ inline double get_mantissa_value(T x)
{
    double mantissa = is_subnormal<T>(x) ? 0.0f : 1.0f;

    for(uint i = 0; i < NumericUtils<T>::mant; i++)
    {
        mantissa += std::pow(2, -int32_t((NumericUtils<T>::mant - i))) * (x & 0b1);
        x >>= 1;
    }

    return mantissa;
}

template <typename T>
__host__ __device__ inline bool get_data_has_inf()
{
    return NumericUtils<T>::has_inf;
}

template <typename T>
__host__ __device__ float convert_to_float(T data, int scale_exp)
{
    float d_sign =
        std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)));

    float d_exp;
    if(is_subnormal<T>(data))
        d_exp = std::pow(2, 1 - static_cast<int>(NumericUtils<T>::bias));
    else
        d_exp = std::pow(2, get_exponent_value<T>(data) - static_cast<int>(NumericUtils<T>::bias));

    float d_mant = get_mantissa_value<T>(data);

    float data_value = d_sign * d_exp * d_mant;

    float scale_value = std::pow(
        2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));

    return data_value * scale_value;
}

template <typename T>
__host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);

template <typename T>
__host__ __device__ T sat_convert_to_type(float value);

template <typename T>
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);

template <typename T>
inline T convert_to_type(float value)
{
    using bitwise_type = typename NumericUtils<T>::bitwise_type;

    if(std::abs(value) > NumericLimits<T>::Max())
    {
        float max_value = NumericLimits<T>::Max();

        cvt t;
        // cppcheck-suppress redundantAssignment
        t.value_float        = max_value;
        uint32_t max_bitwise = t.value_bitwise;
        // cppcheck-suppress redundantAssignment
        t.value_float = value;

        bitwise_type sign =
            t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
        bitwise_type exp =
            ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
            (NumericUtils<float>::bias - NumericUtils<T>::bias);
        bitwise_type mantissa = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);

        uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
        mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
        mant_prev--;
        mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);

        uint32_t prev_bit =
            ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;

        t.value_bitwise = prev_bit;
        float prev_val  = t.value_float;

        float diff       = max_value - prev_val;
        float actual_max = max_value + (diff / 2);

        if(std::abs(value) < actual_max)
        {
            return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
                   (exp << NumericUtils<T>::mant) | mantissa;
        }
        else
        {
            if(!get_data_has_inf<T>())
            {
                return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
            }
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
                       (exp << NumericUtils<T>::mant);
            }
        }
    }

    const int mfmt = NumericUtils<float>::mant;
    uint32_t x;
    x = bit_cast<uint32_t>(value);

    uint32_t head, mantissa;
    int32_t exponent, bias;
    uint32_t sign;

    head     = x & NumericUtils<float>::head_mask;
    mantissa = x & NumericUtils<float>::mant_mask;
    exponent = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
    sign     = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
    bias     = NumericUtils<float>::bias;

    if(x == 0)
    {
        return 0b0;
    }

    const int mini_bias                  = NumericUtils<T>::bias;
    const int mini_denormal_act_exponent = 1 - mini_bias;

    int act_exponent, out_exponent, exponent_diff;
    bool is_subnorm = false;

    if(exponent == 0)
    {
        act_exponent  = exponent - bias + 1;
        exponent_diff = mini_denormal_act_exponent - act_exponent;
        is_subnorm    = true;
    }
    else
    {
        act_exponent = exponent - bias;
        if(act_exponent <= mini_denormal_act_exponent)
        {
            exponent_diff = mini_denormal_act_exponent - act_exponent;
            is_subnorm    = true;
        }
        else
        {
            exponent_diff = 0;
        }
        mantissa += (1UL << mfmt);
    }

    auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
    shift_amount      = (shift_amount >= 64) ? 63 : shift_amount;

    bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));

    float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);

    if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
    {
        // closer to 0
        if(std::abs(value) <= std::abs(min_subnorm - value))
            return 0;
        else
            return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
    }

    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;

    bool implicit_one = mantissa & (1 << mfmt);

    out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);

    uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
    bool odd           = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));

    mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;

    if(out_exponent == 0)
    {
        if((1UL << mfmt) & mantissa)
        {
            out_exponent = 1;
        }
    }
    else
    {
        if((1UL << (mfmt + 1)) & mantissa)
        {
            mantissa >>= 1;
            out_exponent++;
        }
    }

    mantissa >>= (mfmt - NumericUtils<T>::mant);

    if(out_exponent == 0 && mantissa == 0)
    {
        return 0;
    }

    mantissa &= (1UL << NumericUtils<T>::mant) - 1;

    return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
           (out_exponent << NumericUtils<T>::mant) | mantissa;
}

template <typename T>
inline T convert_to_type_sr(float value, uint32_t seed)
{
    if(std::abs(value) > NumericLimits<T>::Max())
    {
        float max_value = NumericLimits<T>::Max();

        cvt t;
        // cppcheck-suppress redundantAssignment
        t.value_float    = max_value;
        uint max_bitwise = t.value_bitwise;
        // cppcheck-suppress redundantAssignment
        t.value_float = value;

        T sign = t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
        T exp  = ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
                (NumericUtils<float>::bias - NumericUtils<T>::bias);

        uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
        mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
        mant_prev--;
        mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);

        uint32_t prev_bit =
            ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;

        t.value_bitwise = prev_bit;
        float prev_val  = t.value_float;

        float diff       = max_value - prev_val;
        float actual_max = max_value + (diff / 2);

        if(std::abs(value) < actual_max)
        {
            double d_max_value  = static_cast<double>(max_value);
            double d_actual_max = static_cast<double>(actual_max);
            double d_value      = static_cast<double>(value);
            double d_is         = std::abs(d_max_value - d_actual_max);
            double d_seed       = static_cast<double>(seed);
            double d_prob       = 1.0f - (std::abs(d_value - d_max_value) / d_is);
            // prob to round down
            double thresh = UINT_MAX * d_prob;

            if(!get_data_has_inf<T>() || d_seed <= thresh)
                // return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
                return sign == 0 ? NumericUtils<f4_t>::data_max_positive_normal_mask
                                 : NumericUtils<f4_t>::data_max_negative_normal_mask;
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
                       | (exp << NumericUtils<T>::mant);
            }
        }
        else
        {
            if(!get_data_has_inf<T>())
                return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
                       | (exp << NumericUtils<T>::mant);
            }
        }
    }

    uint32_t f32  = bit_cast<uint32_t>(value);
    auto f32_mant = f32 & NumericUtils<float>::mant_mask;
    auto head     = f32 & NumericUtils<float>::head_mask;
    auto f32_exp  = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
    auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
    auto sign     = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);

    f32_exp = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;

    int32_t exp = f32_exp;
    auto mant   = f32_mant;

    bool subnorm = false;

    if(f32 == 0)
        return 0b0;

    if(exp >= NumericUtils<T>::unbiased_exp_min)
    {
        mant = f32_mant;
    }
    // if the exponent bit is 8, then the subnormal is exactly the same as f32
    else if(exp < NumericUtils<T>::unbiased_exp_min &&
            NumericUtils<T>::exp < NumericUtils<float>::exp)
    {
        subnorm   = true;
        auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
        if(diff >= 32)
        {
            mant     = 0;
            f32_mant = 0;
        }
        else
        {
            f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
            f32_mant >>= diff;
        }
        exp  = 0;
        mant = f32_mant;
    }

    uint32_t sr_shift = NumericUtils<T>::sr_shift;

    // For stochastic-rounding we add the aligned random value to the
    // mantissa and then truncate (RTZ).
    mant += seed >> sr_shift;

    // Increment exponent when mantissa overflows due to rounding
    if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
        ++exp;

    mant >>= (NumericUtils<float>::mant - NumericUtils<T>::mant);
    mant &= ((1 << NumericUtils<T>::mant) - 1);

    auto biased_exp = static_cast<uint32_t>(exp);
    if(!subnorm)
        biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
    biased_exp &= ((1 << NumericUtils<T>::exp) - 1);

    auto val = sign | biased_exp << NumericUtils<T>::mant | mant;

    return val;
}

} // namespace ck::utils
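To tie the generic decode back to the MX formula, convert_to_float evaluates value = (-1)^sign * 2^(exp - bias_T) * mantissa * 2^(scale_exp - 127). A worked example follows (illustration only, assuming the usual e2m1 parameters for f4_t: 1 mantissa bit, 2 exponent bits, bias 1; the wrapper function is hypothetical):

// Illustrative only: decoding f4_t payload 0b0011 with a scale exponent of 128.
inline float mx_decode_sketch()
{
    // payload 0b0011: sign 0, exponent field 0b01, mantissa bit 1 -> 1.5 * 2^(1-1) = 1.5
    // scale_exp 128 -> 2^(128-127) = 2, so the decoded value is 1.5 * 2 = 3.0
    return ck::utils::convert_to_float<ck::f4_t>(static_cast<ck::f4_t>(0b0011), /*scale_exp=*/128);
}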