Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
687d2b7e
Commit
687d2b7e
authored
Mar 19, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
5d718e6b
f5210953
Changes
161
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
475 additions
and
140 deletions
+475
-140
docs/index.rst
docs/index.rst
+12
-13
docs/install/dockerhub.rst
docs/install/dockerhub.rst
+1
-1
docs/license.md
docs/license.md
+0
-2
docs/license.rst
docs/license.rst
+11
-0
docs/reference/API_Reference_Guide.rst
docs/reference/API_Reference_Guide.rst
+0
-0
docs/reference/Supported_Primitives_Guide.rst
docs/reference/Supported_Primitives_Guide.rst
+0
-0
docs/reference/wrapper.rst
docs/reference/wrapper.rst
+7
-7
docs/sphinx/_toc.yml.in
docs/sphinx/_toc.yml.in
+24
-9
docs/sphinx/requirements.in
docs/sphinx/requirements.in
+1
-1
docs/sphinx/requirements.txt
docs/sphinx/requirements.txt
+1
-1
docs/tutorial/tutorial_hello_world.rst
docs/tutorial/tutorial_hello_world.rst
+0
-0
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+8
-7
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+41
-6
example/01_gemm/gemm_xdl_fp16_fp8.cpp
example/01_gemm/gemm_xdl_fp16_fp8.cpp
+8
-2
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+103
-2
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
+44
-43
example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp
example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp
+44
-43
example/09_convnd_fwd/CMakeLists.txt
example/09_convnd_fwd/CMakeLists.txt
+1
-0
example/09_convnd_fwd/convnd_fwd_common.hpp
example/09_convnd_fwd/convnd_fwd_common.hpp
+88
-3
example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp
example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp
+81
-0
No files found.
docs/index.rst
View file @
687d2b7e
...
...
@@ -12,27 +12,26 @@ The Composable Kernel (CK) library provides a programming model for writing perf
The CK documentation is structured as follows:
.. card:: Conceptual
.. grid:: 2
:gutter: 3
* :ref:`what-is-ck`
.. grid-item-card:: Installation
.. card:: Installation
* :ref:`docker-hub`
* :ref:`docker-hub`
.. grid-item-card:: Conceptual
.. card:: Tutorial
* :ref:`what-is-ck`
* :ref:`hello-world`
.. grid-item-card:: API reference
.. card:: API reference
* :ref:`supported-primitives`
* :ref:`api-reference`
* :ref:`wrapper`
* :ref:`supported-primitives`
* :ref:`api-reference`
* :ref:`wrapper`
.. grid-item-card:: Tutorial
.. card:: Contributing to CK
* :ref:`contributing-to`
* :ref:`hello-world`
To contribute to the documentation refer to `Contributing to ROCm <https://rocm.docs.amd.com/en/latest/contribute/index.html>`_.
...
...
docs/dockerhub.rst
→
docs/
install/
dockerhub.rst
View file @
687d2b7e
...
...
@@ -36,7 +36,7 @@ What is inside the image?
The docker images have everything you need for running CK including:
* `ROCm <https://
www
.amd.com/en/
graphics/servers-solutions-rocm
>`_
* `ROCm <https://
rocm.docs
.amd.com/en/
latest/index.html
>`_
* `CMake <https://cmake.org/getting-started/>`_
* `Compiler <https://github.com/ROCm/llvm-project>`_
* `Composable Kernel library <https://github.com/ROCm/composable_kernel>`_
...
...
docs/license.md
deleted
100644 → 0
View file @
5d718e6b
```
{include} ../LICENSE.md
```
docs/license.rst
0 → 100644
View file @
687d2b7e
.. meta::
:description: Composable Kernel documentation and API reference library
:keywords: composable kernel, CK, ROCm, API, documentation
.. _license:
********************************************************************
License
********************************************************************
.. include:: ../LICENSE
\ No newline at end of file
docs/API_Reference_Guide.rst
→
docs/
reference/
API_Reference_Guide.rst
View file @
687d2b7e
File moved
docs/Supported_Primitives_Guide.rst
→
docs/
reference/
Supported_Primitives_Guide.rst
View file @
687d2b7e
File moved
docs/wrapper.rst
→
docs/
reference/
wrapper.rst
View file @
687d2b7e
...
...
@@ -64,31 +64,31 @@ Advanced examples:
Layout
-------------------------------------
.. doxygenstruct::
ck::wrapper::
Layout
.. doxygenstruct:: Layout
-------------------------------------
Layout helpers
-------------------------------------
.. doxygenfile:: layout_utils.hpp
.. doxygenfile::
include/ck/wrapper/utils/
layout_utils.hpp
-------------------------------------
Tensor
-------------------------------------
.. doxygenstruct::
ck::wrapper::
Tensor
.. doxygenstruct:: Tensor
-------------------------------------
Tensor helpers
-------------------------------------
.. doxygenfile:: tensor_utils.hpp
.. doxygenfile::
include/ck/wrapper/utils/
tensor_utils.hpp
.. doxygenfile:: tensor_partition.hpp
.. doxygenfile::
include/ck/wrapper/utils/
tensor_partition.hpp
-------------------------------------
Operations
-------------------------------------
.. doxygenfile:: copy.hpp
.. doxygenfile:: gemm.hpp
.. doxygenfile::
include/ck/wrapper/operations/
copy.hpp
.. doxygenfile::
include/ck/wrapper/operations/
gemm.hpp
docs/sphinx/_toc.yml.in
View file @
687d2b7e
...
...
@@ -2,20 +2,35 @@ defaults:
numbered: False
root: index
subtrees:
- entries:
- file: what-is-ck.rst
- caption: Conceptual
entries:
- file: conceptual/what-is-ck.rst
title: What is Composable Kernel?
- file: dockerhub.rst
- caption: Install
entries:
- file: install/dockerhub.rst
title: Docker Hub
- file: tutorial_hello_world.rst
title: Hello World Tutorial
- file: Supported_Primitives_Guide.rst
- caption: CK API Reference
entries:
- file: reference/Supported_Primitives_Guide.rst
title: Supported Primitives
- file: API_Reference_Guide.rst
- file:
reference/
API_Reference_Guide.rst
title: API Reference
- file: wrapper.rst
- file:
reference/
wrapper.rst
title: Wrapper
- caption: Tutorial
entries:
- file: tutorial/tutorial_hello_world.rst
title: Hello World Tutorial
- caption: About
entries:
- file: Contributors_Guide.rst
title: Contributing to CK
- file: license.
md
- file: license.
rst
title: License
\ No newline at end of file
docs/sphinx/requirements.in
View file @
687d2b7e
rocm-docs-core==0.3
5
.0
rocm-docs-core==0.3
6
.0
sphinxcontrib-bibtex==2.6.2
docs/sphinx/requirements.txt
View file @
687d2b7e
...
...
@@ -113,7 +113,7 @@ requests==2.31.0
# via
# pygithub
# sphinx
rocm-docs-core==0.3
5
.0
rocm-docs-core==0.3
6
.0
# via -r requirements.in
six==1.16.0
# via
...
...
docs/tutorial_hello_world.rst
→
docs/tutorial
/tutorial
_hello_world.rst
View file @
687d2b7e
File moved
example/01_gemm/CMakeLists.txt
View file @
687d2b7e
...
...
@@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable
(
example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16
)
if
(
GPU_TARGETS MATCHES
"gfx11
00"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102
"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
add_custom_target
(
example_gemm_wmma
)
add_example_executable
(
example_gemm_wmma_fp16 gemm_wmma_fp16.cpp
)
add_example_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
...
...
@@ -53,12 +53,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_example_executable
(
example_gemm_xdl_streamk gemm_xdl_streamk.cpp
)
add_example_executable
(
example_gemm_xdl_fp8 gemm_xdl_fp8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8
)
add_example_executable
(
example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_bf8
)
list
(
APPEND gpu_list gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
...
...
@@ -72,5 +66,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif
()
endforeach
()
add_example_executable
(
example_gemm_xdl_fp8 gemm_xdl_fp8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8
)
add_example_executable
(
example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_bf8
)
add_example_executable
(
example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8
)
example/01_gemm/gemm_wmma_fp16.cpp
View file @
687d2b7e
...
...
@@ -19,15 +19,50 @@ using AElementOp = PassThrough;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
Gemm
MNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
Gemm
Default
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmWmma_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmMNKPadding
,
256
,
128
,
256
,
8
,
8
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
>
;
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
// Prefetch stage
128
,
// BlockSize
64
,
// MPerBlock
128
,
// NPerBlock
64
,
// KPerBlock
8
,
// K1
16
,
// MPerWmma
16
,
// NPerWmma
2
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
4
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
S
<
1
,
32
,
1
,
4
>
,
8
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
example/01_gemm/gemm_xdl_fp16_fp8.cpp
View file @
687d2b7e
...
...
@@ -33,8 +33,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
LoopSched
,
PipelineVer
,
ComputeType
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
ComputeType
>
;
#include "run_gemm_example.inc"
...
...
example/01_gemm/run_gemm_example.inc
View file @
687d2b7e
...
...
@@ -5,6 +5,88 @@
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
template
<
typename
DataType
>
inline
__host__
__device__
constexpr
double
get_rtol
()
{
if
constexpr
(
std
::
is_same_v
<
DataType
,
float
>
)
{
return
1
e
-
3
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
double
>
)
{
return
1
e
-
6
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
half_t
>
)
{
return
1
e
-
3
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bhalf_t
>
)
{
return
5
e
-
2
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
int32_t
>
)
{
return
1
e
-
1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
int8_t
>
)
{
return
1
e
-
1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
f8_t
>
)
{
return
1
e
-
1
;
// 240 and 224 are acceptable
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bf8_t
>
)
{
return
1.5e-1
;
// 57344 and 49152 are acceptable
}
else
{
return
1
e
-
3
;
}
}
template
<
typename
DataType
>
inline
__host__
__device__
constexpr
double
get_atol
()
{
if
constexpr
(
std
::
is_same_v
<
DataType
,
float
>
)
{
return
1
e
-
3
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
double
>
)
{
return
1
e
-
6
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
half_t
>
)
{
return
1
e
-
3
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bhalf_t
>
)
{
return
5
e
-
2
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
int32_t
>
)
{
return
1
e
-
1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
int8_t
>
)
{
return
1
e
-
1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
f8_t
>
)
{
return
16.1
;
// 240 and 224 are acceptable
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bf8_t
>
)
{
return
8192.1
;
// 57344 and 49152 are acceptable
}
else
{
return
1
e
-
3
;
}
}
template
<
typename
ProblemType
>
bool
run_gemm
(
const
ProblemType
&
problem_size
,
const
ExecutionConfig
&
config
)
{
...
...
@@ -68,6 +150,22 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
break
;
case
2
:
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistribution
<
BDataType
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
break
;
case
3
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
break
;
case
4
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
1.
f
,
1.
f
}(
b_k_n
);
break
;
case
5
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
2.
f
,
2.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
2.
f
,
2.
f
}(
b_k_n
);
break
;
default
:
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
0.1
f
,
0.1
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistribution
<
BDataType
>
{
-
0.1
f
,
0.1
f
}(
b_k_n
);
...
...
@@ -240,8 +338,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#else
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
"Error: Incorrect results!"
,
1
e
-
1
,
1
e
-
1
);
return
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
"Error: Incorrect results!"
,
get_rtol
<
CDataType
>
(),
get_atol
<
CDataType
>
());
#endif
}
...
...
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
View file @
687d2b7e
...
...
@@ -65,48 +65,49 @@ using CDEElementOp = AlphaBetaAdd;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD_Wmma_CShuffle
<
ALayout
,
BLayout
,
ck
::
Tuple
<
DLayout
>
,
ELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
256
,
128
,
256
,
8
,
8
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD_Wmma_CShuffle
<
ALayout
,
BLayout
,
ck
::
Tuple
<
DLayout
>
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
2
,
// Prefetch stage
128
,
// BlockSize
128
,
// MPerBlock
64
,
// NPerBlock
64
,
// KPerBlock
8
,
// K1
16
,
// MPerWmma
16
,
// NPerWmma
4
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
2
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
S
<
1
,
32
,
1
,
4
>
,
8
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -264,7 +265,7 @@ int main(int argc, char* argv[])
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
<<
device_op
.
GetTypeString
()
<<
std
::
endl
;
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
...
...
example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp
View file @
687d2b7e
...
...
@@ -55,7 +55,7 @@ using DDataType = I8;
using
EDataType
=
I8
;
using
ALayout
=
Row
;
using
BLayout
=
Row
;
using
BLayout
=
Col
;
using
DLayout
=
Row
;
using
ELayout
=
Row
;
...
...
@@ -65,48 +65,49 @@ using CDEElementOp = AlphaBetaAdd;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD_Wmma_CShuffle
<
ALayout
,
BLayout
,
ck
::
Tuple
<
DLayout
>
,
ELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
32
,
16
,
16
,
4
,
16
,
16
,
16
,
1
,
1
,
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
1
,
S
<
4
,
1
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
2
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
2
>
,
8
>
;
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD_Wmma_CShuffle
<
ALayout
,
BLayout
,
ck
::
Tuple
<
DLayout
>
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
2
,
// Prefetch stage
128
,
// BlockSize
128
,
// MPerBlock
64
,
// NPerBlock
64
,
// KPerBlock
8
,
// K1
16
,
// MPerWmma
16
,
// NPerWmma
4
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
2
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
S
<
1
,
32
,
1
,
4
>
,
8
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
example/09_convnd_fwd/CMakeLists.txt
View file @
687d2b7e
...
...
@@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable
(
example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp
)
add_example_executable
(
example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp
)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing
(
example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp
)
set
(
target 1
)
...
...
example/09_convnd_fwd/convnd_fwd_common.hpp
View file @
687d2b7e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
...
...
@@ -27,6 +27,88 @@ void print_helper_msg()
<<
ck
::
utils
::
conv
::
get_conv_param_parser_helper_msg
()
<<
std
::
endl
;
}
template
<
typename
DataType
>
inline
__host__
__device__
constexpr
double
get_rtol
()
{
if
constexpr
(
std
::
is_same_v
<
DataType
,
float
>
)
{
return
1e-3
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
double
>
)
{
return
1e-6
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
half_t
>
)
{
return
1e-3
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bhalf_t
>
)
{
return
5e-2
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
int32_t
>
)
{
return
1e-1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
int8_t
>
)
{
return
1e-1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
f8_t
>
)
{
return
1e-1
;
// 240 and 224 are acceptable
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bf8_t
>
)
{
return
1.5e-1
;
// 57344 and 49152 are acceptable
}
else
{
return
1e-3
;
}
}
template
<
typename
DataType
>
inline
__host__
__device__
constexpr
double
get_atol
()
{
if
constexpr
(
std
::
is_same_v
<
DataType
,
float
>
)
{
return
1e-3
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
double
>
)
{
return
1e-6
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
half_t
>
)
{
return
1e-3
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bhalf_t
>
)
{
return
5e-2
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
int32_t
>
)
{
return
1e-1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
int8_t
>
)
{
return
1e-1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
f8_t
>
)
{
return
16.1
;
// 240 and 224 are acceptable
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bf8_t
>
)
{
return
8192.1
;
// 57344 and 49152 are acceptable
}
else
{
return
1e-3
;
}
}
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
typename
WeiDataType
,
...
...
@@ -164,8 +246,11 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf
.
FromDevice
(
out_device
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
out_device
,
out_host
,
"Error: incorrect results!"
,
1e-5
f
,
1e-4
f
);
return
ck
::
utils
::
check_err
(
out_device
,
out_host
,
"Error: incorrect results!"
,
get_rtol
<
OutDataType
>
(),
get_atol
<
OutDataType
>
());
}
return
true
;
...
...
example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp
0 → 100644
View file @
687d2b7e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using
InDataType
=
ck
::
f8_t
;
using
WeiDataType
=
ck
::
f8_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
ck
::
f8_t
;
using
OutDataType
=
ck
::
f8_t
;
using
ComputeDataType
=
ck
::
f8_t
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
InElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
ConvSpec
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
template
<
ck
::
index_t
NDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
using
DeviceGroupedConvNDFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
InLayout
,
WeiLayout
,
ck
::
Tuple
<>
,
OutLayout
,
InDataType
,
WeiDataType
,
AccDataType
,
CShuffleDataType
,
ck
::
Tuple
<>
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
ConvSpec
,
// ConvForwardSpecialization
GemmSpec
,
// GemmSpecialization
1
,
//
256
,
// BlockSize
128
,
// MPerBlock
256
,
// NPerBlock
32
,
// KPerBlock
8
,
// AK1
8
,
// BK1
32
,
// MPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
4
,
// NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_AK1
1
,
// ABlockLdsExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_BK0_N_BK1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_BK1
1
,
// BBlockLdsExtraN
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ComputeDataType
>
;
#include "run_convnd_fwd_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_convnd_fwd_example
(
argc
,
argv
)
?
0
:
1
;
}
Prev
1
2
3
4
5
6
7
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment