gaoqiong / composable_kernel_ROCM · Commits

Commit 5d718e6b, authored Mar 04, 2024 by Jun Liu

    Merge branch 'develop' into amd-develop

parents 6d9a07d7 9ce18b04

Showing 3 changed files with 45 additions and 27 deletions:

    example/01_gemm/gemm_xdl_fp8.cpp        +9   -5
    example/01_gemm/gemm_xdl_fp8_bf8.cpp    +4   -4
    include/ck/utility/type_convert.hpp     +32  -18
example/01_gemm/gemm_xdl_fp8.cpp

@@ -20,14 +20,18 @@ using BElementOp = PassThrough;
 using CElementOp = PassThrough;
 
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 
+static constexpr auto LoopSched   = ck::make_default_loop_scheduler();
+static constexpr auto PipelineVer = ck::PipelineVersion::v1;
+using ComputeTypeA = ck::f8_t;
+using ComputeTypeB = ck::f8_t;
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
-// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
+// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute|
+// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB|
+// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
+// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
 // clang-format on
 
 using ReferenceGemmInstance = ck::tensor_operation::host::
...
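Note: the gemm_xdl_fp8 example now passes four extra trailing template arguments (LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB) to DeviceGemm_Xdl_CShuffle, matching the aliases declared just above the instantiation. The sketch below is a minimal stand-in rather than CK's actual declaration: it assumes the appended trailing parameters carry defaults so older instantiations keep compiling, and every name in it (DeviceGemmStub, f8_stub, the stub enums) is invented for illustration.

// Minimal, self-contained C++ sketch (assumption: not the real CK declaration) of
// appending defaulted trailing template parameters, the pattern this example relies
// on when it spells out LoopSched, PipelineVer, ComputeTypeA and ComputeTypeB as the
// last arguments of DeviceGemm_Xdl_CShuffle.
#include <cstdio>

enum class LoopScheduler { Default };    // stand-in for ck::LoopScheduler
enum class PipelineVersion { v1, v2 };   // stand-in for ck::PipelineVersion
struct f8_stub {};                       // stand-in for ck::f8_t

template <typename ADataType,
          typename BDataType,
          // ...the many tuning parameters of the real kernel are elided...
          LoopScheduler LoopSched     = LoopScheduler::Default,
          PipelineVersion PipelineVer = PipelineVersion::v1,
          typename ComputeTypeA       = ADataType, // compute type defaults to the data type
          typename ComputeTypeB       = ComputeTypeA>
struct DeviceGemmStub
{
    static void print() { std::puts("instantiated"); }
};

// An old-style instantiation still compiles because the defaults kick in...
using GemmOld = DeviceGemmStub<f8_stub, f8_stub>;
// ...while the updated example passes the trailing arguments explicitly.
using GemmNew = DeviceGemmStub<f8_stub,
                               f8_stub,
                               LoopScheduler::Default,
                               PipelineVersion::v1,
                               f8_stub,
                               f8_stub>;

int main()
{
    GemmOld::print();
    GemmNew::print();
    return 0;
}

Defaulting ComputeTypeA/ComputeTypeB to the data types keeps existing callers unchanged in the sketch; whether CK defaults them the same way is not visible in this diff.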
example/01_gemm/gemm_xdl_fp8_bf8.cpp

@@ -27,10 +27,10 @@ using ComputeTypeB = ck::bf8_t;
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
-// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute|
+// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB|
+// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
+// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
         < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
 // clang-format on
...
include/ck/utility/type_convert.hpp

@@ -109,9 +109,6 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
 {
     constexpr int seed = 1254739;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
-    float max_fp8 = 240.0f;
-    if(!std::isinf(x))
-        x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
 #if defined(__gfx94__)
     union
     {
@@ -119,10 +116,15 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
         uint32_t i32val;
         uint8_t i8val[4]; // not endian independent
     } val;
     val.fval = x;
     uint32_t ival = 0;
+    const float max_fp8 = 240.0f;
+    // if x is not +/- infinity or nan
+    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
+        // clip float value
+        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
     ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
     val.i32val = ival;
     return val.i8val[0]; // little endian
 #else
     constexpr bool negative_zero_nan = true;
@@ -166,10 +168,15 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
         uint32_t i32val;
         uint8_t i8val[4]; // not endian independent
     } val;
     val.fval = x;
     uint32_t ival = 0;
+    const float max_bf8 = 57344.0f;
+    // if x is not +/- infinity or nan
+    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
+        // clip float value
+        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
     ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
     val.i32val = ival;
     return val.i8val[0]; // little endian
 #else
     constexpr bool negative_zero_nan = true;
...
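Note: in the stochastic-rounding specializations above, the earlier host-side clamp (std::isinf plus nested ternaries) is replaced by a bit-pattern test against NumericUtils<float>::nan_mask followed by __builtin_amdgcn_fmed3f. The median-of-three builtin clamps finite inputs to the largest fp8/bf8 magnitudes (240.0f and 57344.0f), and the preceding test skips Inf/NaN so the converter can encode them itself. The host-only sketch below reproduces that saturation behaviour with an invented helper name and no HIP builtins:

// Host-side sketch (assumption: plain C++ stand-in, no HIP/gfx94 builtins) of the
// saturation performed before the fp8/bf8 conversion intrinsics.
// For finite x, fmed3(x, +max, -max), i.e. the median of the three values, selects
// the same result as clamp(x, -max, +max); Inf and NaN are passed through untouched.
#include <algorithm>
#include <cmath>
#include <cstdio>

// Illustrative helper, not a CK or HIP API.
static float saturate_for_f8(float x, float max_mag /* 240.0f for f8_t, 57344.0f for bf8_t */)
{
    if(std::isinf(x) || std::isnan(x))
        return x;                            // leave Inf/NaN for the converter to encode
    return std::clamp(x, -max_mag, max_mag); // median-of-three == clamp for finite inputs
}

int main()
{
    std::printf("%g\n", saturate_for_f8(1000.0f, 240.0f)); // 240
    std::printf("%g\n", saturate_for_f8(-1.0e6f, 240.0f)); // -240
    std::printf("%g\n", saturate_for_f8(3.5f, 240.0f));    // 3.5
    return 0;
}

On gfx94 hardware the same clamp maps to a single v_med3_f32 instruction, which is presumably why the check moved from generic host code into the intrinsic path.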
@@ -208,9 +215,6 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
 template <>
 inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
 {
-    float max_fp8 = 240.0f;
-    if(!std::isinf(x))
-        x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
 #if defined(__gfx94__)
     union
     {
@@ -218,8 +222,13 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
         uint32_t i32val;
         uint8_t i8val[4]; // not endian independent
     } val;
     val.fval = x;
     uint32_t ival = 0;
+    const float max_fp8 = 240.0f;
+    // if x is not +/- infinity or nan
+    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
+        // clip float value
+        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
     ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
     val.i32val = ival;
     return val.i8val[0];
@@ -263,8 +272,13 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
         uint32_t i32val;
         uint8_t i8val[4]; // not endian independent
     } val;
     val.fval = x;
     uint32_t ival = 0;
+    const float max_bf8 = 57344.0f;
+    // if x is not +/- infinity or nan
+    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
+        // clip float value
+        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
     ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
     val.i32val = ival;
     return val.i8val[0];
...
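Note: both rounding modes exposed by this header receive the same clamp. f8_convert_rne converts with __builtin_amdgcn_cvt_pk_fp8_f32 (round-to-nearest-even), while f8_convert_sr feeds a per-value random word from prand_generator into __builtin_amdgcn_cvt_sr_fp8_f32. The toy sketch below contrasts the two policies on a coarse uniform grid rather than the real fp8 format; the helper names and the grid step are illustrative only.

// Toy sketch (assumption: rounding to a uniform grid, not the actual fp8/bf8 formats)
// contrasting round-to-nearest-even with stochastic rounding, the two behaviours
// selected by f8_convert_rne and f8_convert_sr.
#include <cmath>
#include <cstdint>
#include <cstdio>

// Round x to a multiple of `step` with round-to-nearest, ties to even.
static float round_nearest_even(float x, float step)
{
    return static_cast<float>(std::nearbyint(x / step)) * step; // default FP mode is RNE
}

// Stochastic rounding: round up with probability equal to the fractional part,
// driven by a caller-supplied 32-bit random word (the role of prand_generator's rng).
static float round_stochastic(float x, float step, std::uint32_t rng)
{
    float q    = x / step;
    float down = std::floor(q);
    float frac = q - down;                     // in [0, 1)
    float u    = rng * (1.0f / 4294967296.0f); // uniform in [0, 1)
    return (u < frac ? down + 1.0f : down) * step;
}

int main()
{
    const float x = 1.3f, step = 1.0f;
    std::printf("rne: %g\n", round_nearest_even(x, step));           // 1
    std::printf("sr : %g\n", round_stochastic(x, step, 0xC0000000)); // u = 0.75 > 0.3, rounds down to 1
    std::printf("sr : %g\n", round_stochastic(x, step, 0x20000000)); // u = 0.125 < 0.3, rounds up to 2
    return 0;
}

Averaged over many random words the stochastic result is unbiased (its expectation is x itself), which is the usual reason to prefer it over RNE when accumulating in 8-bit formats.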