gaoqiong / composable_kernel_ROCM · Commits

Commit dbb7002d, authored Feb 06, 2025 by Adam Osewski

    Merge remote-tracking branch 'origin/develop' into aosewski/hotloop

Parents: 96c8d948, 2bef5501
Changes: 228

Showing 20 changed files with 763 additions and 123 deletions (+763, -123):
example/67_gemm_microscaling/CMakeLists.txt (+5, -0)
example/67_gemm_microscaling/README.md (+17, -0)
example/67_gemm_microscaling/gemm_mx_common.hpp (+427, -0)
example/67_gemm_microscaling/gemm_mx_fp8.cpp (+41, -0)
example/CMakeLists.txt (+39, -30)
example/ck_tile/03_gemm/script/benchmark_basic.sh (+2, -2)
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh (+2, -2)
include/ck/README.md (+20, -16)
include/ck/ck.hpp (+14, -4)
include/ck/config.h.in (+4, -0)
include/ck/host_utility/device_prop.hpp (+4, -3)
include/ck/library/utility/check_err.hpp (+75, -24)
include/ck/library/utility/host_tensor_generator.hpp (+43, -0)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp (+2, -2)
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp (+5, -1)
include/ck/tensor_operation/gpu/device/device_base.hpp (+9, -4)
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp (+18, -4)
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp (+3, -1)
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (+32, -28)
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp (+1, -2)
example/67_gemm_microscaling/CMakeLists.txt (new file, mode 100644)

```cmake
add_custom_target(example_gemm_mx)

add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
```
example/67_gemm_microscaling/README.md (new file, mode 100644)

# GEMM Examples for Microscaling Formats

## example_gemm_mx_fp8

```bash
# arg1: verification (0=no, 1=CPU)
# arg2: initialization (0=no init, 1=integer value, 2=decimal value)
# arg3: time kernel (0=no, 1=yes)
# arg4: verbosity (0=no info, 1=verbose info)
# arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC
./bin/example_gemm_mx_fp8 1 1 0 1
```

```bash
# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0
./bin/example_gemm_mx_fp8
```
\ No newline at end of file
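The README stops at the four-argument form; for reference, a fully specified run passes the problem size and strides as arguments 5 to 10. The values below are illustrative only (they are just the defaults from `gemm_mx_common.hpp` further down, where a stride of -1 asks the example to derive a packed default), not a command taken from the README:

```bash
# Hypothetical full invocation: verify on CPU, integer init, time the kernel, verbose,
# M=3840 N=4096 K=4096, strides of -1 = derive packed row-/column-major defaults.
./bin/example_gemm_mx_fp8 1 1 1 1 3840 4096 4096 -1 -1 -1
```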
example/67_gemm_microscaling/gemm_mx_common.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"

using ScaleDataType = ck::e8m0_bexp_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

struct ExecutionConfig final
{
    int do_verification = 1;     // (0=no, 1=CPU)
    int init_method     = 2;     // (0=no init, 1=integer value, 2=decimal value)
    bool time_kernel    = false; // (0=no, 1=yes)
    int verbosity       = 0;     // (0=no info, 1=verbose info)
};

struct ProblemSize final
{
    ck::index_t M = 3840;
    ck::index_t N = 4096;
    ck::index_t K = 4096;

    ck::index_t StrideA = -1;
    ck::index_t StrideB = -1;
    ck::index_t StrideC = -1;
};

bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 5)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
        config.verbosity       = std::stoi(argv[4]);
    }
    else if(argc == 11)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
        config.verbosity       = std::stoi(argv[4]);

        problem_size.M = std::stoi(argv[5]);
        problem_size.N = std::stoi(argv[6]);
        problem_size.K = std::stoi(argv[7]);

        problem_size.StrideA = std::stoi(argv[8]);
        problem_size.StrideB = std::stoi(argv[9]);
        problem_size.StrideC = std::stoi(argv[10]);
    }
    else
    {
        std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl
                  << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
                  << std::endl
                  << "arg3: time kernel (0=no, 1=yes)" << std::endl
                  << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl
                  << "arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC" << std::endl;

        return false;
    }

    return true;
}

template <typename ADataType,
          typename BDataType,
          typename XDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename CElementWiseOp,
          typename AccDataType,
          typename CShuffleDataType,
          ck::index_t MXVectorSize>
bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
    using ELayout    = CLayout;
    using DsLayout   = ck::Tuple<>;
    using DsDataType = ck::Tuple<>;

    using AElementOp   = PassThrough;
    using BElementOp   = PassThrough;
    using CDEElementOp = CElementWiseOp;

    static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
    static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
    static constexpr auto BlkGemmPVer   = ck::BlockGemmPipelineVersion::v3;

#if 1
    // XXX: These parameters should not exist in MX-native GEMM kernel
    static constexpr ck::index_t Scale_Block_M = 128;
    static constexpr ck::index_t Scale_Block_N = 128;
#endif
    static constexpr ck::index_t Scale_Block_K = MXVectorSize;

    // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize MX-specific MFMA
    // instructions.
    //
    // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize device-optimized
    // scaled type convert functions.
    //
    // XXX: In DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3, KPerBlock is expected to be equal to
    // ScaleBlockK (aka MXVectorSize).
    // Additionally, the following is also expected:
    //     static_assert(ScaleBlockM % MPerBlock == 0);
    //     static_assert(ScaleBlockN % NPerBlock == 0);
    // In MX-native GEMM kernel these requirements should be relaxed.
    //
    // XXX: It appears, by default we are using mfma_f32_16x16x4xf32
    // MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk =
    // MfmaSelector<float, 16, 16, float>::selected_mfma.k_per_blk = mfma_f32_16x16x4xf32
    // XXX: GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 assumes scale type is float

    // clang-format off
    // ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB|
    // ######|        |        |         |        |          | DataType|        | DataType|         |  DataType|        |          | Operation| Operation| Operation|       | Size|          |            |            | Per| Per| Per|   |   | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer|  |  |  |  |
    // ######|        |        |         |        |          |        |        |        |          |          |        |          |          |          |          |       |     |          |            |            |Block|Block| Block|  |   | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder|  |  | PerVector| PerVector_AK1|  | Lengths| ArrangeOrder|  | Dim| PerVector| PerVector_BK1|  | PerWave| PerWave| MBlock_MPerBlock| PerVectors|  |  |  |  |  |  |
    // ######|        |        |         |        |          |        |        |        |          |          |        |          |          |          |          |       |     |          |            |            |     |     |      |  |   |    |    |    |    | AK0_M_AK1|  |  |  |  |  |  | BK0_N_BK1|  |  |  |  |  |PerShuffle|PerShuffle| NBlock_NPerBlock|  |  |  |  |  |  |  |
    using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<
        ALayout, BLayout, DsLayout, ELayout,
        ADataType, XDataType, BDataType, XDataType, DsDataType, CDataType,
        AccDataType, CShuffleDataType,
        AElementOp, BElementOp, CDEElementOp, GemmSpec,
        256,
        Scale_Block_M, Scale_Block_N, Scale_Block_K,
        128, 128, 128,
        16, 16,
        16, 16,
        4, 4,
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
        1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
        BlkGemmPSched, BlkGemmPVer,
        float, float, float, float>;
    // clang-format on

    auto M       = problem_size.M;
    auto N       = problem_size.N;
    auto K       = problem_size.K;
    auto StrideA = problem_size.StrideA;
    auto StrideB = problem_size.StrideB;
    auto StrideC = problem_size.StrideC;

    auto f_host_tensor_descriptor =
        [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
            {
                return HostTensorDescriptor({row, col}, {stride, 1});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1, stride});
            }
        };

    auto f_get_default_stride =
        [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
            if(stride == -1)
            {
                // give a chance if stride is -1, return a default packed stride
                if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
                {
                    return static_cast<ck::index_t>(col);
                }
                else
                {
                    return static_cast<ck::index_t>(row);
                }
            }
            else
                return static_cast<ck::index_t>(stride);
        };

    StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
    StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
    StrideC = f_get_default_stride(M, N, StrideC, CLayout{});

    if(K % Scale_Block_K != 0)
    {
        throw std::runtime_error("wrong! K must be multiple of Scale_Block_K (16 or 32)");
    };

    auto Scale_Stride_AM = f_get_default_stride(M, K / Scale_Block_K, StrideA, ALayout{});
    auto Scale_Stride_BN = f_get_default_stride(K / Scale_Block_K, N, StrideB, BLayout{});
```
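For orientation, this is the product the example computes: with scale-block size V = `MXVectorSize` (= `Scale_Block_K`) along K, each element of A and B is multiplied by the scale of the K-block it belongs to, so the host reference further down implements

```math
C_{m,n} = \sum_{k=0}^{K-1} \left(A_{m,k} \cdot s^{A}_{m,\lfloor k/V \rfloor}\right) \left(B_{k,n} \cdot s^{B}_{\lfloor k/V \rfloor,n}\right)
```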
```cpp
    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<XDataType> a_m_k_scale(
        f_host_tensor_descriptor(M, K / Scale_Block_K, Scale_Stride_AM, ALayout{})); // scales for A
    Tensor<XDataType> b_k_n_scale(
        f_host_tensor_descriptor(K / Scale_Block_K, N, Scale_Stride_BN, BLayout{})); // scales for B

    Tensor<CDataType> c_m_n_host_result(
        f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification
    Tensor<CDataType> c_m_n_device_result(
        f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // device result downloaded to host

    if(config.verbosity >= 0)
    {
        std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
        std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl;
        std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
        std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl;
        std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl;
    }

    switch(config.init_method)
    {
    case 0:
        if(config.verbosity > 0)
        {
            std::cout << "NOTE: No input data initialization." << std::endl;
        }
        break;
    case 1:
    case 2:
        ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k);
        ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(a_m_k_scale);
        ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.0f)}(b_k_n);
        ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(b_k_n_scale);
        if(config.verbosity > 0)
        {
            std::cout << "Init A = {1}" << std::endl;
            std::cout << "Init A scale = {0.5}" << std::endl;
            std::cout << "Init B = {1}" << std::endl;
            std::cout << "Init B scale = {2.0}" << std::endl;
            std::cout << "Expect C = {K}" << std::endl;
        }
        break;
    default:
        if(config.verbosity > 0)
        {
            std::cout << "NOTE: No input data initialization." << std::endl;
        }
    }
```
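A quick check of the constant initialization above: every term of the K-sum is (1 × 0.5) × (1 × 2.0) = 1, so each output element should equal exactly K. That is what the `Expect C = {K}` message announces, and what the `init_method == 1` comparison against `static_cast<float>(K)` further down relies on.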
```cpp
    if(config.verbosity > 0)
        std::cout << "Device memory allocation..." << std::endl;

    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());

    if(config.verbosity > 0)
        std::cout << "Upload data to device..." << std::endl;

    a_device_buf.ToDevice(a_m_k.mData.data());
    a_scale_device_buf.ToDevice(a_m_k_scale.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    b_scale_device_buf.ToDevice(b_k_n_scale.mData.data());

    if(config.verbosity > 0)
        std::cout << "Done." << std::endl;

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{};

    constexpr ck::index_t NumDTensor = DsDataType::Size();

    // do GEMM
    auto device_op = DeviceOpInstance{};
    auto invoker   = device_op.MakeInvoker();
    auto argument  = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                            b_device_buf.GetDeviceBuffer(),
                                            std::array<const void*, NumDTensor>{},
                                            c_device_buf.GetDeviceBuffer(),
                                            M,
                                            N,
                                            K,
                                            StrideA,
                                            StrideB,
                                            std::array<ck::index_t, NumDTensor>{},
                                            StrideC,
                                            a_scale_device_buf.GetDeviceBuffer(),
                                            b_scale_device_buf.GetDeviceBuffer(),
                                            a_element_op,
                                            b_element_op,
                                            cde_element_op);

    if(!device_op.IsSupportedArgument(argument))
    {
        throw std::runtime_error("wrong!\n"
                                 "Provided combination of compilation and runtime parameters is "
                                 "not consistent with the supported device_gemm arguments.");
    }

    if(config.verbosity > 0)
        std::cout << "Computing GEMM on device..." << std::endl;

    float ave_time =
        invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50});

    bool res_verified = true;

    if(config.do_verification > 0)
    {
        c_device_buf.FromDevice(c_m_n_device_result.mData.data());

        if(config.verbosity > 0)
        {
            std::cout << "Done." << std::endl;
            std::cout << "Computing GEMM on host..." << std::endl;
        }

        Tensor<CDataType> c({M, N});
        Tensor<float> a({M, K});
        Tensor<float> b({K, N});

        for(int m = 0; m < M; m++)
        {
            for(int k = 0; k < K; k++)
            {
                a(m, k) = ck::type_convert<float>(a_m_k(m, k)) *
                          ck::type_convert<float>(a_m_k_scale(m, k / Scale_Block_K));
            }
        }
        for(int n = 0; n < N; n++)
        {
            for(int k = 0; k < K; k++)
            {
                b(k, n) = ck::type_convert<float>(b_k_n(k, n)) *
                          ck::type_convert<float>(b_k_n_scale(k / Scale_Block_K, n));
            }
        }

        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<float,
                                                                                float,
                                                                                CShuffleDataType,
                                                                                CDataType,
                                                                                PassThrough,
                                                                                PassThrough,
                                                                                PassThrough>;

        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument =
            ref_gemm.MakeArgument(a, b, c, PassThrough{}, PassThrough{}, PassThrough{});

        ref_invoker.Run(ref_argument);

        if(config.verbosity > 0)
        {
            std::cout << "Done." << std::endl;
            std::cout << "Comparing results..." << std::endl;
        }

        if(config.init_method == 1)
        {
            res_verified = res_verified &&
                           std::abs(static_cast<float>(K) - c_m_n_device_result(0, 0)) <= 0.0f;

            std::cout << "Expected vs Computed: " << 1.0f * K << " vs "
                      << c_m_n_device_result(0, 0)
                      << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl;
        }

        res_verified = res_verified &&
                       ck::utils::check_err(c_m_n_device_result, c, "Error: Incorrect results!");

        if(config.verbosity > 0 && res_verified)
            std::cout << "Done." << std::endl;
    }
    else
    {
        if(config.verbosity > 0)
            std::cout << "Done." << std::endl;
    }

    if(config.time_kernel)
    {
        std::size_t flop = std::size_t(2) * M * N * K + M * K + K * N; // GEMM + A scale + B scale
        std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                                sizeof(CDataType) * M * N +
                                sizeof(XDataType) * (M * K + K * N) / Scale_Block_K;

        float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
        float gb_per_sec = num_btype / 1.E6 / ave_time;

        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                  << " GB/s" << std::endl;
    }

    return res_verified;
}
```
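To make the units in the `Perf:` line concrete, here is a minimal standalone sanity check of the same bookkeeping, assuming the default problem size and a made-up 1 ms runtime (both values are illustrative): `flop / 1e9` yields GFLOP, and GFLOP per millisecond equals TFLOP/s.

```cpp
#include <cstddef>
#include <cstdio>

int main()
{
    // Default ProblemSize values; 1.0 ms is an assumed runtime, for illustration only.
    const std::size_t M = 3840, N = 4096, K = 4096;
    const std::size_t flop = std::size_t(2) * M * N * K + M * K + K * N; // GEMM + A scale + B scale
    const float ave_time = 1.0f;                                         // ms (assumed)
    const float tflops   = static_cast<float>(flop) / 1.E9 / ave_time;   // GFLOP/ms == TFLOP/s
    std::printf("flop = %zu -> %.2f TFlops at 1 ms\n", flop, tflops);    // ~128.88 TFlops
    return 0;
}
```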
```cpp
template <typename ADataType,
          typename BDataType,
          typename XDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename CElementWiseOp,
          typename AccDataType,
          typename CShuffleDataType,
          ck::index_t MXVectorSize>
bool run_mx_gemm_example(int argc, char* argv[])
{
    ProblemSize problem_size;
    ExecutionConfig config;

    return parse_cmd_args(argc, argv, problem_size, config) &&
           run_mx_gemm<ADataType,
                       BDataType,
                       XDataType,
                       CDataType,
                       ALayout,
                       BLayout,
                       CLayout,
                       CElementWiseOp,
                       AccDataType,
                       CShuffleDataType,
                       MXVectorSize>(problem_size, config);
}
```
example/67_gemm_microscaling/gemm_mx_fp8.cpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "gemm_mx_common.hpp"

using ADataType = ck::f8_t;
using BDataType = ck::f8_t;

#if 1
// XXX: MX-native GEMM kernel will work with e8m0_bexp_t scale type
using XDataType = float;
#else
using XDataType = ck::e8m0_bexp_t;
#endif

using AccDataType      = float;
using CShuffleDataType = float;
using CDataType        = float;

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

using CElementOp = PassThrough; // elementwise transformation for C matrix

constexpr ck::index_t mx_vector_size = 128; // scaling block size

int main(int argc, char* argv[])
{
    return run_mx_gemm_example<ADataType,
                               BDataType,
                               XDataType,
                               CDataType,
                               ALayout,
                               BLayout,
                               CLayout,
                               CElementOp,
                               AccDataType,
                               CShuffleDataType,
                               mx_vector_size>(argc, argv)
               ? 0
               : -1;
}
```
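The `XXX` comments and the disabled `#else` branch above point at `ck::e8m0_bexp_t`, which, per the OCP microscaling (MX) specification, is an 8-bit exponent-only scale with bias 127. Below is a minimal standalone sketch of how such a scale decodes; `decode_e8m0` is a hypothetical helper for orientation, not CK's implementation:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical decoder: e8m0 stores only a biased exponent, so value = 2^(bits - 127).
float decode_e8m0(std::uint8_t bits) { return std::ldexp(1.0f, static_cast<int>(bits) - 127); }

int main()
{
    // 127 -> 1.0, 126 -> 0.5, 128 -> 2.0: exactly the A/B scale values used in the example.
    std::printf("%g %g %g\n", decode_e8m0(127), decode_e8m0(126), decode_e8m0(128));
    return 0;
}
```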
example/CMakeLists.txt

```diff
@@ -23,34 +23,34 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
     message("adding example ${EXAMPLE_NAME}")
     set(result 1)
     if(DEFINED DTYPES)
         foreach(source IN LISTS FILE_NAME)
             set(test 0)
             if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
                 set(test 1)
             endif()
             if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
                 set(test 1)
             endif()
             if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
                 set(test 1)
             endif()
             if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
                 set(test 1)
             endif()
             if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
                 set(test 1)
             endif()
             if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
                 set(test 1)
             endif()
             if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
                 set(test 1)
             endif()
             if(test EQUAL 1)
                 message("removing example source file ${source}")
                 list(REMOVE_ITEM FILE_NAME "${source}")
             endif()
         endforeach()
     endif()
     set(EX_TARGETS ${SUPPORTED_GPU_TARGETS})
@@ -83,6 +83,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
             list(REMOVE_ITEM FILE_NAME "${source}")
         endif()
     endforeach()
+    #Do not build any microscaling examples if gfx950 target is not on the list
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx")
+            message("removing microscaling example ${source}")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
     #Do not build any FP8 examples if CK_ENABLE_FP8 not set
     foreach(source IN LISTS FILE_NAME)
         if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8")
@@ -102,7 +109,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
     if(FILE_NAME MATCHES "_xdl")
         list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
     elseif(FILE_NAME MATCHES "_wmma")
-        list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+        list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
+    elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950
+        list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
     endif()
     set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
     add_executable(${EXAMPLE_NAME} ${FILE_NAME})
@@ -195,7 +204,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
     if(FILE_NAME MATCHES "_xdl")
         list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
     elseif(FILE_NAME MATCHES "_wmma")
-        list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+        list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
     endif()
     set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
     add_executable(${EXAMPLE_NAME} ${FILE_NAME})
```
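The new `_mx` filter can be exercised in isolation. A minimal sketch of the rule added in the second hunk, runnable with `cmake -P` (all variable values below are hypothetical):

```cmake
# Illustration only: with no gfx950 in the target list, the mx source is dropped.
set(EX_TARGETS "gfx908;gfx90a")
set(FILE_NAME "gemm_mx_fp8.cpp")
foreach(source IN LISTS FILE_NAME)
    if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx")
        message("removing microscaling example ${source}")
        list(REMOVE_ITEM FILE_NAME "${source}")
    endif()
endforeach()
message("remaining sources: '${FILE_NAME}'") # prints an empty list
```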
example/ck_tile/03_gemm/script/benchmark_basic.sh

```diff
 #!/bin/sh

 EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
-VALID=0
+VALID=1

 for b_matrix_layout in "R" "C"; do
     for m in "64" "512" "1024" "2048"; do
         for n in "512" "1024" "2048"; do
             for k in "64" "512" "1024" "2048"; do
-                $EXE -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
+                $EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
             done
         done
     done
 done
```
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh

```diff
 #!/bin/sh

 EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
-VALID=0
+VALID=1

 for b_matrix_layout in "R" "C"; do
     for m in "64" "512" "1024" "2048"; do
         for n in "512" "1024" "2048"; do
             for k in "64" "512" "1024" "2048"; do
-                $EXE -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
+                $EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
             done
         done
     done
 done
```
include/ck/README.md

```diff
 [Back to the main page](../../README.md)

 # Composable Kernel supported operations

 ## Supported device operations

-* [Average pooling]()
-* [Batched contraction]()
-* [Batched gemm]()
-* [Batchnorm]()
-* [CGEMM]()
-* [Contraction]()
-* [Convolution]()
-* [Image to Column and Column to Image]()
-* [Elementwise]()
-* [GEMM]()
-* [Max pooling]()
-* [Reduce]()
-* [Normalization]()
-* [Permute]()
-* [Put]()
-* [Softmax]()
+<!-- * [Average pooling](../../docs/markdown/tensor_operation/average_pooling.md) -->
+<!-- * [Batched contraction](../../docs/markdown/tensor_operation/batched_contraction.md) -->
+<!-- * [Batched gemm](../../docs/markdown/tensor_operation/batched_gemm.md) -->
+<!-- * [Batchnorm](../../docs/markdown/tensor_operation/batchnorm.md) -->
+<!-- * [CGEMM](../../docs/markdown/tensor_operation/cgemm.md) -->
+<!-- * [Contraction](../../docs/markdown/tensor_operation/contraction.md) -->
+<!-- * [Convolution](../../docs/markdown/tensor_operation/convolution.md) -->
+<!-- * [Elementwise](../../docs/markdown/tensor_operation/elementwise.md) -->
+* [GEMM](../../client_example/01_gemm/README.md)
+* [Grouped Convolution Forward](../../client_example/07_grouped_convnd_fwd/README.md)
+* [Grouped Convolution Backward Data](../../client_example/10_grouped_convnd_bwd_data/README.md)
+* [Grouped Convolution Backward Weight](../../client_example/11_grouped_conv_bwd_weight/README.md)
+<!-- * [Grouped GEMM](../../docs/markdown/tensor_operation/grouped_gemm.md) -->
+<!-- * [Image to Column and Column to Image](../../docs/markdown/tensor_operation/img2col.md) -->
+<!-- * [Max pooling](../../docs/markdown/tensor_operation/max_pooling.md) -->
+<!-- * [Reduce](../../docs/markdown/tensor_operation/reduce.md) -->
+<!-- * [Normalization](../../docs/markdown/tensor_operation/normalization.md) -->
+<!-- * [Permute](../../docs/markdown/tensor_operation/permute.md) -->
+<!-- * [Put](../../docs/markdown/tensor_operation/put.md) -->
+<!-- * [Softmax](../../docs/markdown/tensor_operation/softmax.md) -->
```
include/ck/ck.hpp

```diff
@@ -5,7 +5,7 @@
 #include "ck/config.h"
 #include "ck/utility/env.hpp"

+#ifndef CK_CODE_GEN_RTC
 #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
 #include "hip/hip_runtime.h"
 #include "hip/hip_fp16.h"
@@ -14,7 +14,7 @@
 // environment variable to enable logging:
 // export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
 CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
+#endif

 // to do: add various levels of logging with CK_LOG_LEVEL
 #ifndef CK_TIME_KERNEL
@@ -55,10 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 // define general macros for various architectures
 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__)
+    defined(__gfx942__) || defined(__gfx950__)
 #define __gfx9__
 #endif
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
 #define __gfx94__
 #endif
 #if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
@@ -163,6 +163,16 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 // set rounding to nearest even as default for f8 conversions
 #define CK_USE_SR_F8_CONVERSION 0

+// set rounding to nearest even as default for f6 conversions
+#define CK_USE_SR_F6_CONVERSION 0
+
+// set rounding to nearest even as default for f4 conversions
+#define CK_USE_SR_F4_CONVERSION 0
+
+// shuffle pk_i4 values during conversion to optimize number of binary
+// operations
+#define CK_USE_PK4_LAYOUT_SHUFFLE 1
+
 // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
 #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
```
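A compact sketch of how the widened `__gfx9__` umbrella macro is consumed; the guard shape mirrors the one rewritten in the codegen header later in this commit (hypothetical translation unit, compiled under hipcc):

```cpp
#include "ck/ck.hpp"

__device__ void dispatch_example()
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // MFMA-capable path: after this change, also taken on gfx950
#else
    // stub for targets without MFMA
#endif
}
```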
include/ck/config.h.in

```diff
@@ -131,6 +131,10 @@
 #cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
 #endif

+#ifndef CK_USE_NATIVE_MX_SUPPORT
+#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@
+#endif
+
 // clang-format on
 #endif // CK_CONFIG_H_IN
```
include/ck/host_utility/device_prop.hpp

```diff
@@ -55,20 +55,21 @@ inline bool is_xdl_supported()
 {
     return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
            ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
-           ck::get_device_name() == "gfx942";
+           ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
 }

 inline bool is_lds_direct_load_supported()
 {
     // Check if direct loads from global memory to LDS are supported.
     return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
-           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
+           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942" ||
+           ck::get_device_name() == "gfx950";
 }

 inline bool is_bf16_atomic_supported()
 {
     return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
-           ck::get_device_name() == "gfx942";
+           ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
 }

 inline bool is_gfx101_supported()
```
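A short host-side sketch of how these queries are typically consumed (hypothetical program; the functions are exactly the ones extended above):

```cpp
#include "ck/host_utility/device_prop.hpp"
#include <iostream>

int main()
{
    // After this change both queries also return true on gfx950.
    std::cout << "device: " << ck::get_device_name() << "\n"
              << "XDL (MFMA) kernels: " << (ck::is_xdl_supported() ? "yes" : "no") << "\n"
              << "bf16 atomics:       " << (ck::is_bf16_atomic_supported() ? "yes" : "no") << "\n";
    return 0;
}
```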
include/ck/library/utility/check_err.hpp

```diff
@@ -26,6 +26,7 @@ namespace utils {
 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
 double get_relative_threshold(const int number_of_accumulations = 1)
 {
+    using F4   = ck::f4_t;
     using F8   = ck::f8_t;
     using F16  = ck::half_t;
     using BF16 = ck::bhalf_t;
@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
     using I8  = int8_t;
     using I32 = int32_t;

-    static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
-                      is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
-                      is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
-                      is_same_v<ComputeDataType, int>,
+    static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
+                      is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
+                      is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
+                      is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
                   "Warning: Unhandled ComputeDataType for setting up the relative threshold!");

     double compute_error = 0;
     if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
         compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5;
     }

-    static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
-                      is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
-                      is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
-                      is_same_v<OutDataType, int>,
+    static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
+                      is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
+                      is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
+                      is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
                   "Warning: Unhandled OutDataType for setting up the relative threshold!");

     double output_error = 0;
     if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
     }

     double midway_error = std::max(compute_error, output_error);

-    static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
-                      is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
-                      is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
-                      is_same_v<AccDataType, int>,
+    static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
+                      is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
+                      is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
+                      is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
                   "Warning: Unhandled AccDataType for setting up the relative threshold!");

     double acc_error = 0;
     if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
 template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
 double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
 {
+    using F4   = ck::f4_t;
     using F8   = ck::f8_t;
     using F16  = ck::half_t;
     using BF16 = ck::bhalf_t;
@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
     using I8  = int8_t;
     using I32 = int32_t;

-    static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
-                      is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
-                      is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
-                      is_same_v<ComputeDataType, int>,
+    static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
+                      is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
+                      is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
+                      is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
                   "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");

     auto expo            = std::log2(std::abs(max_possible_num));
     double compute_error = 0;
@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
         compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5;
     }

-    static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
-                      is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
-                      is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
-                      is_same_v<OutDataType, int>,
+    static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
+                      is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
+                      is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
+                      is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
                   "Warning: Unhandled OutDataType for setting up the absolute threshold!");

     double output_error = 0;
     if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
     }

     double midway_error = std::max(compute_error, output_error);

-    static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
-                      is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
-                      is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
-                      is_same_v<AccDataType, int>,
+    static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
+                      is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
+                      is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
+                      is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
                   "Warning: Unhandled AccDataType for setting up the absolute threshold!");

     double acc_error = 0;
     if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
@@ -450,5 +452,54 @@ check_err(const Range& out,
     return res;
 }

+template <typename Range, typename RefRange>
+std::enable_if_t<
+    (std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+     std::is_same_v<ranges::range_value_t<Range>, f4_t>),
+    bool>
+check_err(const Range& out,
+          const RefRange& ref,
+          const std::string& msg = "Error: Incorrect results!",
+          double rtol            = 0.5,
+          double atol            = 0.5)
+{
+    if(out.size() != ref.size())
+    {
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+                  << std::endl;
+        return false;
+    }
+
+    bool res{true};
+    int err_count  = 0;
+    double err     = 0;
+    double max_err = std::numeric_limits<float>::min();
+
+    for(std::size_t i = 0; i < ref.size(); ++i)
+    {
+        const double o = type_convert<float>(*std::next(std::begin(out), i));
+        const double r = type_convert<float>(*std::next(std::begin(ref), i));
+        err            = std::abs(o - r);
+        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
+        {
+            max_err = err > max_err ? err : max_err;
+            err_count++;
+            if(err_count < 5)
+            {
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
+            }
+            res = false;
+        }
+    }
+
+    if(!res)
+    {
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
+                  << " number of errors: " << err_count << std::endl;
+    }
+
+    return res;
+}
+
 } // namespace utils
 } // namespace ck
```
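A hypothetical use of the new `f4_t` overload; note the defaults `rtol = atol = 0.5`, reflecting how coarse the f4 value grid is:

```cpp
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/data_type.hpp"
#include <vector>

int main()
{
    // Assumed usage sketch: two identical f4 buffers compare clean.
    std::vector<ck::f4_t> out(16, ck::type_convert<ck::f4_t>(1.0f));
    std::vector<ck::f4_t> ref(16, ck::type_convert<ck::f4_t>(1.0f));
    return ck::utils::check_err(out, ref, "Error: f4 results mismatch!") ? 0 : 1;
}
```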
include/ck/library/utility/host_tensor_generator.hpp

```diff
@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t>
 };
 #endif

+template <>
+struct GeneratorTensor_1<ck::f4_t>
+{
+    float value = 1.0;
+
+    template <typename... Is>
+    ck::f4_t operator()(Is...)
+    {
+        return ck::type_convert<ck::f4_t>(value);
+    }
+};
+
 template <>
 struct GeneratorTensor_1<int8_t>
 {
@@ -183,6 +195,20 @@ struct GeneratorTensor_2<ck::bf8_t>
 };
 #endif

+template <>
+struct GeneratorTensor_2<ck::f4_t>
+{
+    int min_value = 0;
+    int max_value = 1;
+
+    template <typename... Is>
+    ck::f4_t operator()(Is...)
+    {
+        float tmp = (std::rand() % (max_value - min_value)) + min_value;
+        return ck::type_convert<ck::f4_t>(tmp);
+    }
+};
+
 template <typename T>
 struct GeneratorTensor_3
 {
@@ -253,6 +279,23 @@ struct GeneratorTensor_3<ck::bf8_t>
 };
 #endif

+template <>
+struct GeneratorTensor_3<ck::f4_t>
+{
+    float min_value = 0;
+    float max_value = 1;
+
+    template <typename... Is>
+    ck::f4_t operator()(Is...)
+    {
+        float tmp = float(std::rand()) / float(RAND_MAX);
+
+        float fp32_tmp = min_value + tmp * (max_value - min_value);
+
+        return ck::type_convert<ck::f4_t>(fp32_tmp);
+    }
+};
+
 template <typename T>
 struct GeneratorTensor_4
 {
```
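A hypothetical fill using the new `GeneratorTensor_3<ck::f4_t>` specialization, mirroring how the other generators are used with the host `Tensor` utilities (method names are assumed to match the host_tensor helpers used elsewhere in this commit):

```cpp
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

int main()
{
    // Assumed usage: fill a 4x4 host tensor with random f4 values drawn from [0, 1).
    Tensor<ck::f4_t> t({4, 4});
    t.GenerateTensorValue(GeneratorTensor_3<ck::f4_t>{0.f, 1.f});
    return 0;
}
```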
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp

```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -131,7 +131,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
     }

     template <typename T>
-    using is_tuple = decltype(std::declval<T&>().IsTuple());
+    using is_tuple = decltype(ck::declval<T&>().IsTuple());

     template <typename DstBuffers, index_t ThreadScratchId = 0>
     __device__ void RunWrite(const DstDescs& dst_descs,
```
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp

```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef CK_CODE_GEN_RTC
 #include <string>
+#endif

 namespace ck {
 namespace tensor_operation {
@@ -18,6 +20,7 @@ enum struct ConvolutionForwardSpecialization
     Filter3x3,
 };

+#ifndef CK_CODE_GEN_RTC
 inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s)
 {
     switch(s)
@@ -30,6 +33,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp
     default: return "Unrecognized specialization!";
     }
 }
+#endif

 } // namespace device
 } // namespace tensor_operation
```
include/ck/tensor_operation/gpu/device/device_base.hpp

```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef CK_CODE_GEN_RTC
 #include <string>
 #include <sstream>
 #include <regex>
 #include <optional>

 #include "ck/stream_config.hpp"
+#endif

 namespace ck {
 namespace tensor_operation {
 namespace device {

+#ifndef CK_CODE_GEN_RTC
 #define GET_OBJECT_NAME_IMLP                                  \
     std::optional<std::string> GetObjectName() const override \
     {                                                         \
@@ -41,7 +43,9 @@ namespace device {
 }

 #define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
+#endif

+#ifndef CK_CODE_GEN_RTC
 struct BaseArgument
 {
     BaseArgument() = default;
@@ -66,13 +70,14 @@ struct BaseInvoker
     virtual ~BaseInvoker() {}
 };
+#endif

 struct BaseOperator
 {
     BaseOperator()                               = default;
     BaseOperator(const BaseOperator&)            = default;
     BaseOperator& operator=(const BaseOperator&) = default;

+#ifndef CK_CODE_GEN_RTC
     virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
     virtual std::string GetTypeString() const { return ""; }
@@ -100,7 +105,7 @@ struct BaseOperator
         assert(p_arg);
         p_arg->p_workspace_ = p_workspace;
     }
+#endif

     virtual ~BaseOperator() {}
 };
```
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp

```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef CK_CODE_GEN_RTC
 #include <array>
+#endif

 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
@@ -13,8 +15,13 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

+#ifdef CK_CODE_GEN_RTC
+template <typename T>
+using is_tuple = decltype(ck::declval<T&>().IsTuple());
+#else
 template <typename T>
 using is_tuple = decltype(std::declval<T&>().IsTuple());
+#endif

 /**
  * \brief Grouped Convolution Forward
@@ -72,12 +79,18 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
     static constexpr index_t NumDTensor = DsDataType::Size();

     static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");

+#ifdef CK_CODE_GEN_RTC
+    using APointers =
+        ck::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
+    using BPointers =
+        ck::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
+#else
     // If DataType is tuple, user has to pass std::array with pointers.
     using APointers =
-        std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
+        ck::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
     using BPointers =
-        std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
+        ck::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
+#endif

+#ifndef CK_CODE_GEN_RTC
     /**
      * \brief Make argument pointer for grouped conv fwd.
@@ -150,6 +163,7 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
         const CDEElementwiseOperation& cde_element_op) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+#endif
 };

 } // namespace device
```
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp

```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -29,6 +29,7 @@ enum struct GemmSpecialization
     MNKOPadding,
 };

+#ifndef CK_CODE_GEN_RTC
 inline std::string getGemmSpecializationString(const GemmSpecialization& s)
 {
     switch(s)
@@ -52,6 +53,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
     default: return "Unrecognized specialization!";
     }
 }
+#endif

 } // namespace device
 } // namespace tensor_operation
```
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
dbb7002d
...
@@ -3,11 +3,17 @@

 #pragma once

+#ifndef CK_CODE_GEN_RTC
 #include <functional>
 #include <iostream>
 #include <iterator>
 #include <numeric>
 #include <sstream>
+#include <stdio.h>
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+#endif

 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
...
@@ -15,15 +21,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
 #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
-#include "ck/host_utility/device_prop.hpp"
-#include "ck/host_utility/kernel_launch.hpp"
-#include "ck/host_utility/io.hpp"
 namespace ck {
 namespace tensor_operation {
...
@@ -91,8 +94,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
     const Block2ETileMap block_2_ctile_map,
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
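The guard rewritten above collapses the explicit gfx908/gfx90a/gfx94 list into a single `__gfx9__` family check. A hedged sketch of the pattern, assuming `__gfx9__` is a family macro covering the gfx9 targets (presumably defined centrally, e.g. in `ck/ck.hpp`, rather than by the compiler):

```cpp
// Hedged sketch of the target-guard pattern used by these kernels.
// The host pass (!defined(__HIP_DEVICE_COMPILE__)) must also see the body
// so the kernel symbol exists for the launch machinery; unsupported device
// targets compile an empty stub instead.
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // ... real kernel body ...
#else
    // ... empty stub ...
#endif
```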
...
@@ -259,8 +261,13 @@ __global__ void
 } // namespace

+#ifdef CK_CODE_GEN_RTC
+template <typename T>
+using is_tuple = decltype(ck::declval<T&>().IsTuple());
+#else
 template <typename T>
 using is_tuple = decltype(std::declval<T&>().IsTuple());
+#endif
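The alias above is the expression half of the classic detection idiom: `is_tuple<T>` is well-formed only for types exposing an `IsTuple()` member, and a `void_t`-style detector turns that into a compile-time boolean. A self-contained sketch using the `std::` spellings from the non-RTC branch; `has_is_tuple` and `Probe` are hypothetical names, not from the diff:

```cpp
#include <type_traits>
#include <utility>

// Expression alias: well-formed only if T has a member IsTuple().
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

// Detector: SFINAE on whether is_tuple<T> is well-formed.
template <typename T, typename = void>
struct has_is_tuple : std::false_type {};

template <typename T>
struct has_is_tuple<T, std::void_t<is_tuple<T>>> : std::true_type {};

struct Probe { void IsTuple(); }; // any type exposing IsTuple() counts

static_assert(has_is_tuple<Probe>::value);
static_assert(!has_is_tuple<int>::value);
```

Under RTC the same idiom is reproduced with `ck::declval`, since `<utility>` is not available there.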
 //
 // @brief Device Convolution operation.
...
@@ -429,8 +436,8 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     // If we are using multiAB and one of the template datatype parameters is not a tuple, convert
     // it to it
-    using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
-    using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
+    using GemmADataType = ck::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
+    using GemmBDataType = ck::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;

 #define GridwiseGemmTemplateParameters \
     GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
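The wrapping rule in the hunk above is worth spelling out: when exactly one of A/B is multi-tensor, the other side's scalar data type is promoted to a one-element tuple so both operands present a tuple-shaped interface to the gridwise GEMM. A hedged illustration with `std::` stand-ins for `ck::conditional_t`/`ck::Tuple` and a hypothetical scalar type:

```cpp
#include <tuple>
#include <type_traits>

using ADataType = float;             // hypothetical scalar A type
inline constexpr bool isMultiA = false;
inline constexpr bool isMultiB = true;

// Same shape as the diff's rule: wrap A only when B alone is multi-tensor.
using GemmADataType =
    std::conditional_t<!isMultiA && isMultiB, std::tuple<ADataType>, ADataType>;

static_assert(std::is_same_v<GemmADataType, std::tuple<float>>);
```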
...
@@ -449,15 +456,13 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
         CDEBlockTransferScalarPerVector_NPerBlock, LoopSched

     // Use appropriate gridwise gemm
     using GridwiseGemm =
-        std::conditional_t<isMultiA || isMultiB,
+        ck::conditional_t<isMultiA || isMultiB,
                            GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmTemplateParameters>,
                            GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
     // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers.
     using APointers =
-        ck::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
+        std::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
     using BPointers =
-        ck::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
+        std::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;

     // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
     // in initializer list what is required for single const pointer).
     using AGridPointer = remove_cvref_t<
...
@@ -812,7 +817,6 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
         static_for<0, NumDTensor, 1>{}([&](auto i) {
             using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

             // FIXME: layout
             if constexpr(is_same_v<DLayout, ctc::G_NW_K> || is_same_v<DLayout, ctc::G_NHW_K> ||
                          is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
...
@@ -965,18 +969,18 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                           const BElementwiseOperation& b_element_op,
                           const CDEElementwiseOperation& cde_element_op)
     {
-        std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
-        std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
-        std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
-        std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
-        std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
-        std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
-        std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
-        std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
-        std::array<index_t, NDimSpatial> conv_filter_strides_i32;
-        std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
-        std::array<index_t, NDimSpatial> input_left_pads_i32;
-        std::array<index_t, NDimSpatial> input_right_pads_i32;
+        ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
+        ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
+        ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
+        ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
+        ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
+        ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
+        ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
+        ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
+        ck::Array<index_t, NDimSpatial> conv_filter_strides_i32;
+        ck::Array<index_t, NDimSpatial> conv_filter_dilations_i32;
+        ck::Array<index_t, NDimSpatial> input_left_pads_i32;
+        ck::Array<index_t, NDimSpatial> input_right_pads_i32;

         array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
         array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
...
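`array_convert` is used above to narrow the caller's wide host-side descriptors into the 32-bit `index_t` arrays the kernel argument wants. A minimal sketch of what such a helper can look like, inferred from these call sites rather than taken from CK's actual definition:

```cpp
#include <array>
#include <cstdint>
#include <type_traits>

// Hedged sketch: element-wise static_cast between two fixed-size arrays of
// the same extent, e.g. std::array<int64_t, N> -> std::array<int32_t, N>
// (or a ck::Array destination, as in the diff above).
template <typename DstArray, typename SrcArray>
void array_convert(DstArray& dst, const SrcArray& src)
{
    for(std::size_t i = 0; i < src.size(); ++i)
    {
        dst[i] = static_cast<std::remove_reference_t<decltype(dst[0])>>(src[i]);
    }
}

// Usage sketch:
//   std::array<std::int64_t, 4> src{1, 2, 3, 4};
//   std::array<std::int32_t, 4> dst;
//   array_convert(dst, src);
```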
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp  View file @ dbb7002d
...
@@ -56,8 +56,7 @@ __global__ void
             const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
             const Block2ETileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t num_blocks_per_batch =
...