Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9ba504b6
Commit
9ba504b6
authored
Feb 07, 2025
by
ThomasNing
Browse files
merge with the develop support the fp8 with computev4
parents
e3402c93
f49de496
Changes
198
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
712 additions
and
90 deletions
+712
-90
example/62_convnd_activ/dynamic_unary/CMakeLists.txt
example/62_convnd_activ/dynamic_unary/CMakeLists.txt
+1
-1
example/62_convnd_activ/multi_AB/CMakeLists.txt
example/62_convnd_activ/multi_AB/CMakeLists.txt
+1
-1
example/62_convnd_activ/unary/CMakeLists.txt
example/62_convnd_activ/unary/CMakeLists.txt
+1
-1
example/67_gemm_microscaling/CMakeLists.txt
example/67_gemm_microscaling/CMakeLists.txt
+5
-0
example/67_gemm_microscaling/README.md
example/67_gemm_microscaling/README.md
+17
-0
example/67_gemm_microscaling/gemm_mx_common.hpp
example/67_gemm_microscaling/gemm_mx_common.hpp
+415
-0
example/67_gemm_microscaling/gemm_mx_fp8.cpp
example/67_gemm_microscaling/gemm_mx_fp8.cpp
+41
-0
example/CMakeLists.txt
example/CMakeLists.txt
+39
-30
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+32
-6
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+43
-6
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+25
-8
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
+0
-0
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
+0
-0
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
+14
-0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
+3
-3
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
...ple/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
+13
-0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
+13
-0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
+13
-0
example/ck_tile/03_gemm/script/smoke_test_basic.sh
example/ck_tile/03_gemm/script/smoke_test_basic.sh
+18
-17
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
+18
-17
No files found.
example/62_convnd_activ/dynamic_unary/CMakeLists.txt
View file @
9ba504b6
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
gfx950
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
...
...
example/62_convnd_activ/multi_AB/CMakeLists.txt
View file @
9ba504b6
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
gfx950
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
...
...
example/62_convnd_activ/unary/CMakeLists.txt
View file @
9ba504b6
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
gfx950
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
...
...
example/67_gemm_microscaling/CMakeLists.txt
0 → 100644
View file @
9ba504b6
add_custom_target
(
example_gemm_mx
)
add_example_executable
(
example_gemm_mx_fp8 gemm_mx_fp8.cpp
)
add_example_dependencies
(
example_gemm_mx example_gemm_mx_fp8
)
example/67_gemm_microscaling/README.md
0 → 100644
View file @
9ba504b6
# GEMM Examples for Microscaling Formats
## example_gemm_mx_fp8
```
bash
# arg1: verification (0=no, 1=CPU)
# arg2: initialization (0=no init, 1=integer value, 2=decimal value)
# arg3: time kernel (0=no, 1=yes)
# arg4: verbosity (0=no info, 1=verbose info)
# arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC
./bin/example_gemm_mx_fp8 1 1 0 1
```
```
bash
# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0
./bin/example_gemm_mx_fp8
```
\ No newline at end of file
example/67_gemm_microscaling/gemm_mx_common.hpp
0 → 100644
View file @
9ba504b6
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
using
ScaleDataType
=
ck
::
e8m0_bexp_t
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
struct
ExecutionConfig
final
{
int
do_verification
=
1
;
// (0=no, 1=CPU)
int
init_method
=
2
;
// (0=no init, 1=integer value, 2=decimal value)
bool
time_kernel
=
false
;
// (0=no, 1=yes)
int
verbosity
=
0
;
// (0=no info, 1=verbose info)
};
struct
ProblemSize
final
{
ck
::
index_t
M
=
3840
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
-
1
;
ck
::
index_t
StrideB
=
-
1
;
ck
::
index_t
StrideC
=
-
1
;
};
bool
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ProblemSize
&
problem_size
,
ExecutionConfig
&
config
)
{
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
5
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
config
.
verbosity
=
std
::
stoi
(
argv
[
4
]);
}
else
if
(
argc
==
11
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
config
.
verbosity
=
std
::
stoi
(
argv
[
4
]);
problem_size
.
M
=
std
::
stoi
(
argv
[
5
]);
problem_size
.
N
=
std
::
stoi
(
argv
[
6
]);
problem_size
.
K
=
std
::
stoi
(
argv
[
7
]);
problem_size
.
StrideA
=
std
::
stoi
(
argv
[
8
]);
problem_size
.
StrideB
=
std
::
stoi
(
argv
[
9
]);
problem_size
.
StrideC
=
std
::
stoi
(
argv
[
10
]);
}
else
{
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg4: verbosity (0=no info, 1=verbose info)"
<<
std
::
endl
<<
"arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC"
<<
std
::
endl
;
return
false
;
}
return
true
;
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
XDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
CElementWiseOp
,
typename
AccDataType
,
typename
CShuffleDataType
,
ck
::
index_t
MXVectorSize
>
bool
run_mx_gemm
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
{
using
ELayout
=
CLayout
;
using
DsLayout
=
ck
::
Tuple
<>
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
CElementWiseOp
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
BlkGemmPSched
=
ck
::
BlockGemmPipelineScheduler
::
Intrawave
;
static
constexpr
auto
BlkGemmPVer
=
ck
::
BlockGemmPipelineVersion
::
v3
;
#if 1
// XXX: These parameters should not exist in MX-native GEMM kernel
static
constexpr
ck
::
index_t
Scale_Block_M
=
128
;
static
constexpr
ck
::
index_t
Scale_Block_N
=
128
;
#endif
static
constexpr
ck
::
index_t
Scale_Block_K
=
MXVectorSize
;
// XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize MX-specific MFMA
// instructions.
//
// XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize device-optimized
// scaled type convert functions.
//
// XXX: In DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3, KPerBlock is expected to be equal to
// ScaleBlockK (aka MXVectorSize).
// Additionally, the following is also expected:
// static_assert(ScaleBlockM % MPerBlock == 0);
// static_assert(ScaleBlockN % NPerBlock == 0);
// In MX-native GEMM kernel these requirements should be relaxed.
//
// XXX: It appears, by default we are using mfma_f32_16x16x4xf32
// MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk =
// MfmaSelector<float, 16, 16, float>::selected_mfma.k_per_blk = mfma_f32_16x16x4xf32
// XXX: GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 assumes scale type is float
// clang-format off
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
// ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB|
// ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | |
// ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | |
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
XDataType
,
BDataType
,
XDataType
,
DsDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
256
,
Scale_Block_M
,
Scale_Block_N
,
Scale_Block_K
,
128
,
128
,
128
,
16
,
16
,
16
,
16
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPSched
,
BlkGemmPVer
,
float
,
float
,
float
,
float
>
;
// clang-format on
auto
M
=
problem_size
.
M
;
auto
N
=
problem_size
.
N
;
auto
K
=
problem_size
.
K
;
auto
StrideA
=
problem_size
.
StrideA
;
auto
StrideB
=
problem_size
.
StrideB
;
auto
StrideC
=
problem_size
.
StrideC
;
auto
f_host_tensor_descriptor
=
[](
ck
::
index_t
row
,
ck
::
index_t
col
,
ck
::
index_t
stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
,
stride
});
}
};
auto
f_get_default_stride
=
[](
ck
::
index_t
row
,
ck
::
index_t
col
,
ck
::
index_t
stride
,
auto
layout
)
{
if
(
stride
==
-
1
)
{
// give a chance if stride is -1, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
static_cast
<
ck
::
index_t
>
(
col
);
}
else
{
return
static_cast
<
ck
::
index_t
>
(
row
);
}
}
else
return
static_cast
<
ck
::
index_t
>
(
stride
);
};
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
StrideB
=
f_get_default_stride
(
K
,
N
,
StrideB
,
BLayout
{});
StrideC
=
f_get_default_stride
(
M
,
N
,
StrideC
,
CLayout
{});
if
(
K
%
Scale_Block_K
!=
0
)
{
throw
std
::
runtime_error
(
"wrong! K must be multiple of Scale_Block_K (16 or 32)"
);
};
auto
Scale_Stride_AM
=
f_get_default_stride
(
M
,
K
/
Scale_Block_K
,
StrideA
,
ALayout
{});
auto
Scale_Stride_BN
=
f_get_default_stride
(
K
/
Scale_Block_K
,
N
,
StrideB
,
BLayout
{});
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
XDataType
>
a_m_k_scale
(
f_host_tensor_descriptor
(
M
,
K
/
Scale_Block_K
,
Scale_Stride_AM
,
ALayout
{}));
// scales for A
Tensor
<
XDataType
>
b_k_n_scale
(
f_host_tensor_descriptor
(
K
/
Scale_Block_K
,
N
,
Scale_Stride_BN
,
BLayout
{}));
// scales for B
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
// host verification
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
// device result downloaded to host
if
(
config
.
verbosity
>=
0
)
{
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k_scale: "
<<
a_m_k_scale
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n_scale: "
<<
b_k_n_scale
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_device_result: "
<<
c_m_n_device_result
.
mDesc
<<
std
::
endl
;
}
switch
(
config
.
init_method
)
{
case
0
:
if
(
config
.
verbosity
>
0
)
{
std
::
cout
<<
"NOTE: No input data initialization."
<<
std
::
endl
;
}
break
;
case
1
:
case
2
:
ck
::
utils
::
FillConstant
<
ADataType
>
{
ck
::
type_convert
<
ADataType
>
(
1.0
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
XDataType
>
{
ck
::
type_convert
<
XDataType
>
(
0.5
f
)}(
a_m_k_scale
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
ck
::
type_convert
<
BDataType
>
(
1.0
f
)}(
b_k_n
);
ck
::
utils
::
FillConstant
<
XDataType
>
{
ck
::
type_convert
<
XDataType
>
(
2.0
f
)}(
b_k_n_scale
);
if
(
config
.
verbosity
>
0
)
{
std
::
cout
<<
"Init A = {1}"
<<
std
::
endl
;
std
::
cout
<<
"Init A scale = {0.5}"
<<
std
::
endl
;
std
::
cout
<<
"Init B = {1}"
<<
std
::
endl
;
std
::
cout
<<
"Init B scale = {2.0}"
<<
std
::
endl
;
std
::
cout
<<
"Expect C = {K}"
<<
std
::
endl
;
}
break
;
default:
if
(
config
.
verbosity
>
0
)
{
std
::
cout
<<
"NOTE: No input data initialization."
<<
std
::
endl
;
}
}
if
(
config
.
verbosity
>
0
)
std
::
cout
<<
"Device memory allocation..."
<<
std
::
endl
;
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_scale_device_buf
(
sizeof
(
XDataType
)
*
a_m_k_scale
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_scale_device_buf
(
sizeof
(
XDataType
)
*
b_k_n_scale
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
if
(
config
.
verbosity
>
0
)
std
::
cout
<<
"Upload data to device..."
<<
std
::
endl
;
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_scale_device_buf
.
ToDevice
(
a_m_k_scale
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_scale_device_buf
.
ToDevice
(
b_k_n_scale
.
mData
.
data
());
if
(
config
.
verbosity
>
0
)
std
::
cout
<<
"Done."
<<
std
::
endl
;
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
constexpr
ck
::
index_t
NumDTensor
=
DsDataType
::
Size
();
// do GEMM
auto
device_op
=
DeviceOpInstance
{};
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
device_op
.
MakeArgument
(
a_device_buf
.
GetDeviceBuffer
(),
b_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
NumDTensor
>
{},
c_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
StrideA
,
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
{},
StrideC
,
a_scale_device_buf
.
GetDeviceBuffer
(),
b_scale_device_buf
.
GetDeviceBuffer
(),
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
device_op
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong!
\n
"
"Provided combination of compilation and runtime parameters is "
"not consistent with the supported device_gemm arguments."
);
}
if
(
config
.
verbosity
>
0
)
std
::
cout
<<
"Computing GEMM on device..."
<<
std
::
endl
;
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
,
config
.
verbosity
,
20
,
50
});
bool
res_verified
=
true
;
if
(
config
.
do_verification
>
0
)
{
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
if
(
config
.
verbosity
>
0
)
{
std
::
cout
<<
"Done."
<<
std
::
endl
;
std
::
cout
<<
"Computing GEMM on host..."
<<
std
::
endl
;
}
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceMXGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
float
,
PassThrough
,
PassThrough
,
PassThrough
,
float
,
float
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
a_m_k_scale
,
b_k_n
,
b_k_n_scale
,
c_m_n_host_result
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
if
(
config
.
verbosity
>
0
)
{
std
::
cout
<<
"Done."
<<
std
::
endl
;
std
::
cout
<<
"Comparing results..."
<<
std
::
endl
;
}
if
(
config
.
init_method
==
1
)
{
res_verified
=
res_verified
&&
std
::
abs
(
static_cast
<
float
>
(
K
)
-
c_m_n_device_result
(
0
,
0
))
<=
0.0
f
;
std
::
cout
<<
"Expected vs Computed: "
<<
1.0
f
*
K
<<
" vs "
<<
c_m_n_device_result
(
0
,
0
)
<<
((
res_verified
)
?
" (PASSED!)"
:
" (FAILED!)"
)
<<
std
::
endl
;
}
res_verified
=
res_verified
&&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
"Error: Incorrect results!"
);
if
(
config
.
verbosity
>
0
&&
res_verified
)
std
::
cout
<<
"Done."
<<
std
::
endl
;
}
else
{
if
(
config
.
verbosity
>
0
)
std
::
cout
<<
"Done."
<<
std
::
endl
;
}
if
(
config
.
time_kernel
)
{
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
+
M
*
K
+
K
*
N
;
// GEMM + A scale + B scale
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
+
sizeof
(
XDataType
)
*
(
M
*
K
+
K
*
N
)
/
Scale_Block_K
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
}
return
res_verified
;
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
XDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
CElementWiseOp
,
typename
AccDataType
,
typename
CShuffleDataType
,
ck
::
index_t
MXVectorSize
>
bool
run_mx_gemm_example
(
int
argc
,
char
*
argv
[])
{
ProblemSize
problem_size
;
ExecutionConfig
config
;
return
parse_cmd_args
(
argc
,
argv
,
problem_size
,
config
)
&&
run_mx_gemm
<
ADataType
,
BDataType
,
XDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
,
CElementWiseOp
,
AccDataType
,
CShuffleDataType
,
MXVectorSize
>
(
problem_size
,
config
);
}
example/67_gemm_microscaling/gemm_mx_fp8.cpp
0 → 100644
View file @
9ba504b6
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
using
ADataType
=
ck
::
f8_t
;
using
BDataType
=
ck
::
f8_t
;
#if 1
// XXX: MX-native GEMM kernel will work with e8m0_bexp_t scale type
using
XDataType
=
float
;
#else
using
XDataType
=
ck
::
e8m0_bexp_t
;
#endif
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
CDataType
=
float
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
CLayout
=
Row
;
using
CElementOp
=
PassThrough
;
// elementwise transformation for C matrix
constexpr
ck
::
index_t
mx_vector_size
=
128
;
// scaling block size
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_mx_gemm_example
<
ADataType
,
BDataType
,
XDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
,
CElementOp
,
AccDataType
,
CShuffleDataType
,
mx_vector_size
>
(
argc
,
argv
)
?
0
:
-
1
;
}
example/CMakeLists.txt
View file @
9ba504b6
...
...
@@ -23,34 +23,34 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
message
(
"adding example
${
EXAMPLE_NAME
}
"
)
set
(
result 1
)
if
(
DEFINED DTYPES
)
foreach
(
source IN LISTS FILE_NAME
)
set
(
test 0
)
if
((
source MATCHES
"_fp16"
OR source MATCHES
"_f16"
)
AND NOT
"fp16"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
((
source MATCHES
"_fp32"
OR source MATCHES
"_f32"
)
AND NOT
"fp32"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
((
source MATCHES
"_fp64"
OR source MATCHES
"_f64"
)
AND NOT
"fp64"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
((
source MATCHES
"_fp8"
OR source MATCHES
"_f8"
)
AND NOT
"fp8"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
((
source MATCHES
"_bf8"
OR source MATCHES
"_bf8"
)
AND NOT
"bf8"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
((
source MATCHES
"_bf16"
OR source MATCHES
"_b16"
)
AND NOT
"bf16"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
((
source MATCHES
"_int8"
OR source MATCHES
"_i8"
)
AND NOT
"int8"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
(
test EQUAL 1
)
message
(
"removing example source file
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS FILE_NAME
)
set
(
test 0
)
if
((
source MATCHES
"_fp16"
OR source MATCHES
"_f16"
)
AND NOT
"fp16"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
((
source MATCHES
"_fp32"
OR source MATCHES
"_f32"
)
AND NOT
"fp32"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
((
source MATCHES
"_fp64"
OR source MATCHES
"_f64"
)
AND NOT
"fp64"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
((
source MATCHES
"_fp8"
OR source MATCHES
"_f8"
)
AND NOT
"fp8"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
((
source MATCHES
"_bf8"
OR source MATCHES
"_bf8"
)
AND NOT
"bf8"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
((
source MATCHES
"_bf16"
OR source MATCHES
"_b16"
)
AND NOT
"bf16"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
((
source MATCHES
"_int8"
OR source MATCHES
"_i8"
)
AND NOT
"int8"
IN_LIST DTYPES
)
set
(
test 1
)
endif
()
if
(
test EQUAL 1
)
message
(
"removing example source file
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endforeach
()
endif
()
set
(
EX_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
...
...
@@ -83,6 +83,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endforeach
()
#Do not build any microscaling examples if gfx950 target is not on the list
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT EX_TARGETS MATCHES
"gfx950"
AND source MATCHES
"_mx"
)
message
(
"removing microscaling example
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endforeach
()
#Do not build any FP8 examples if CK_ENABLE_FP8 not set
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT DEFINED CK_ENABLE_FP8 AND source MATCHES
"_fp8"
)
...
...
@@ -102,7 +109,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
if
(
FILE_NAME MATCHES
"_xdl"
)
list
(
REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic
)
elseif
(
FILE_NAME MATCHES
"_wmma"
)
list
(
REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
list
(
REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950
)
elseif
(
FILE_NAME MATCHES
"_mx"
)
#only build mx example for gfx950
list
(
REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic
)
endif
()
set_source_files_properties
(
${
FILE_NAME
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
EXAMPLE_NAME
}
${
FILE_NAME
}
)
...
...
@@ -195,7 +204,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
if
(
FILE_NAME MATCHES
"_xdl"
)
list
(
REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic
)
elseif
(
FILE_NAME MATCHES
"_wmma"
)
list
(
REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
list
(
REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
gfx950
)
endif
()
set_source_files_properties
(
${
FILE_NAME
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
EXAMPLE_NAME
}
${
FILE_NAME
}
)
...
...
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
9ba504b6
...
...
@@ -12,7 +12,13 @@
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
...
...
@@ -25,7 +31,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
// This part comes from the Codegen
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Tile
=
64
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
...
...
@@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
...
...
@@ -99,12 +105,32 @@ int run_gemm_example(int argc, char* argv[])
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
else
{
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
9ba504b6
...
...
@@ -48,6 +48,33 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
// ToDo: Add more bias config to support different categories of GEMM.
};
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
bf16_t
>
{
using
ADataType
=
ck_tile
::
bf16_t
;
using
BDataType
=
ck_tile
::
bf16_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
bf16_t
;
};
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
fp8_t
>
{
using
ADataType
=
ck_tile
::
fp8_t
;
using
BDataType
=
ck_tile
::
fp8_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
half_t
;
};
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
bf8_t
>
{
using
ADataType
=
ck_tile
::
bf8_t
;
using
BDataType
=
ck_tile
::
bf8_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
half_t
;
};
template
<
typename
T
>
struct
DataTypeTraits
;
...
...
@@ -69,13 +96,23 @@ struct DataTypeTraits<ck_tile::half_t>
static
constexpr
const
char
*
name
=
"fp16"
;
};
using
Types
=
GemmBasicTypeConfig
<
ck_tile
::
half_t
>
;
template
<
>
struct
DataTypeTraits
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
// Specific type aliases for easy access
using
ADataType
=
Types
::
ADataType
;
using
BDataType
=
Types
::
BDataType
;
using
AccDataType
=
Types
::
AccDataType
;
using
CDataType
=
Types
::
CDataType
;
template
<
>
struct
DataTypeTraits
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
DataTypeTraits
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
auto
create_args
(
int
argc
,
char
*
argv
[])
{
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
9ba504b6
...
...
@@ -9,6 +9,7 @@ static constexpr inline auto is_row_major(Layout layout_)
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>>
{};
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
>
auto
calculate_rtol_atol
(
const
ck_tile
::
index_t
K
,
const
ck_tile
::
index_t
kbatch
,
const
float
max_accumulated_value
)
...
...
@@ -29,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
...
...
@@ -55,7 +57,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args
.
stride_B
=
stride_B
;
args
.
stride_C
=
stride_C
;
float
ave_time
=
gemm_calc
<
ALayout
,
BLayout
,
CLayout
>
(
float
ave_time
=
gemm_calc
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
...
...
@@ -66,13 +69,19 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std
::
cout
<<
"Run Gemm kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" A_Layout ="
<<
ALayout
::
name
<<
" B_Layout ="
<<
BLayout
::
name
<<
" C_Layout ="
<<
CLayout
::
name
<<
" A Type = "
<<
DataTypeTraits
<
ADataType
>::
name
<<
" B Type = "
<<
DataTypeTraits
<
BDataType
>::
name
<<
" C Type = "
<<
DataTypeTraits
<
CDataType
>::
name
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
return
ave_time
;
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
PrecType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
int
run_gemm_example_with_layouts
(
int
argc
,
char
*
argv
[],
const
ALayout
a_layout
=
ALayout
{},
...
...
@@ -83,6 +92,11 @@ int run_gemm_example_with_layouts(int argc,
if
(
!
result
)
return
-
1
;
using
ADataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
ADataType
;
using
BDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
BDataType
;
using
CDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
CDataType
;
using
AccDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
AccDataType
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
...
...
@@ -114,8 +128,8 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
FillMonotonicSeq
<
ADataType
>
{}(
a_m_k
);
ck_tile
::
FillMonotonicSeq
<
BDataType
>
{}(
b_k_n
);
}
else
if
(
init_method
==
2
)
{
ck_tile
::
FillConstant
<
ADataType
>
{
1.
f
}(
a_m_k
);
ck_tile
::
FillConstant
<
BDataType
>
{
1.
f
}(
b_k_n
);
ck_tile
::
FillConstant
<
ADataType
>
{
static_cast
<
ADataType
>
(
1
)
}(
a_m_k
);
ck_tile
::
FillConstant
<
BDataType
>
{
static_cast
<
BDataType
>
(
1
)
}(
b_k_n
);
}
else
{
a_m_k
.
SetZero
();
b_k_n
.
SetZero
();
...
...
@@ -130,7 +144,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
invoke_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
invoke_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
M
,
...
...
@@ -156,7 +171,8 @@ int run_gemm_example_with_layouts(int argc,
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
,
"Error: Incorrect results!"
,
...
...
@@ -213,7 +229,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
,
"Error: Incorrect results!"
,
...
...
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
0 → 100644
View file @
9ba504b6
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
0 → 100644
View file @
9ba504b6
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
0 → 100644
View file @
9ba504b6
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp8
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
\ No newline at end of file
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
View file @
9ba504b6
...
...
@@ -2,10 +2,10 @@
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"R"
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
b_matrix_layout
in
"C"
;
do
for
m
in
"512"
"1024"
"2048"
"4096"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
for
k
in
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp16
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
...
...
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
0 → 100644
View file @
9ba504b6
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
m
in
"512"
"1024"
"2048"
"4096"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
bf16
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
\ No newline at end of file
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
0 → 100644
View file @
9ba504b6
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
m
in
"512"
"1024"
"2048"
"4096"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
bf8
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
\ No newline at end of file
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
0 → 100644
View file @
9ba504b6
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"C"
;
do
for
m
in
"512"
"1024"
"2048"
"4096"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp8
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
\ No newline at end of file
example/ck_tile/03_gemm/script/smoke_test_basic.sh
View file @
9ba504b6
...
...
@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS
=
'-v=2 -warmup=0 -repeat=1'
run_fp16_tests
()
{
for
batch
in
1 2
;
do
for
m
in
128 1024
;
do
for
n
in
128 2048
;
do
for
k
in
32 64
;
do
$EXE
-b
=
$batch
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-e
=
1e-5
-prec
=
fp16
$COMMON_ARGS
if
[
$?
-eq
0
]
;
then
echo
"Success: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
executed successfully."
else
echo
"Error: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
run_tests
()
{
for
m
in
128 1024
;
do
for
n
in
128 2048
;
do
for
k
in
64 128
;
do
$EXE
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-prec
=
$1
$COMMON_ARGS
if
[
$?
-eq
0
]
;
then
echo
"Success: Test with m=
$m
, n=
$n
, k=
$k
executed successfully."
else
echo
"Error: Test with m=
$m
, n=
$n
, k=
$k
failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
...
...
@@ -30,6 +28,9 @@ run_fp16_tests() {
set
-x
run_fp16_tests
run_tests
"fp16"
run_tests
"bf16"
run_tests
"fp8"
run_tests
"bf8"
set
+x
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
View file @
9ba504b6
...
...
@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS
=
'-v=2 -warmup=0 -repeat=1'
run_fp16_tests
()
{
for
batch
in
1 2
;
do
for
m
in
128 1024
;
do
for
n
in
128 2048
;
do
for
k
in
32 64
;
do
$EXE
-b
=
$batch
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-e
=
1e-5
-prec
=
fp16
$COMMON_ARGS
if
[
$?
-eq
0
]
;
then
echo
"Success: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
executed successfully."
else
echo
"Error: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
run_tests
()
{
for
m
in
512 1024
;
do
for
n
in
512 2048
;
do
for
k
in
512 1024
;
do
$EXE
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-prec
=
$1
$COMMON_ARGS
if
[
$?
-eq
0
]
;
then
echo
"Success: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
executed successfully."
else
echo
"Error: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
...
...
@@ -30,6 +28,9 @@ run_fp16_tests() {
set
-x
run_fp16_tests
run_tests
"fp16"
run_tests
"bf16"
run_tests
"fp8"
run_tests
"bf8"
set
+x
Prev
1
2
3
4
5
6
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment