Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e3d444c8
Commit
e3d444c8
authored
Oct 02, 2024
by
Mirza Halilcevic
Browse files
Merge remote-tracking branch 'upstream/develop' into ck_migraphx_integration
parents
24608d43
11b7a4db
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1232 additions
and
50 deletions
+1232
-50
Jenkinsfile
Jenkinsfile
+7
-7
docs/sphinx/requirements.in
docs/sphinx/requirements.in
+1
-1
docs/sphinx/requirements.txt
docs/sphinx/requirements.txt
+1
-1
example/66_complex_contraction_bilinear/CMakeLists.txt
example/66_complex_contraction_bilinear/CMakeLists.txt
+3
-0
example/66_complex_contraction_bilinear/README.md
example/66_complex_contraction_bilinear/README.md
+11
-0
example/66_complex_contraction_bilinear/common_instances.hpp
example/66_complex_contraction_bilinear/common_instances.hpp
+196
-0
example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp
...action_bilinear/complex_contraction_bilinear_xdl_fp32.cpp
+86
-0
example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp
...action_bilinear/complex_contraction_bilinear_xdl_fp64.cpp
+86
-0
example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc
...ion_bilinear/run_complex_contraction_bilinear_example.inc
+484
-0
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+15
-2
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+32
-15
example/ck_tile/01_fmha/fmha_fwd.hpp
example/ck_tile/01_fmha/fmha_fwd.hpp
+2
-7
example/ck_tile/04_img2col/CMakeLists.txt
example/ck_tile/04_img2col/CMakeLists.txt
+3
-0
example/ck_tile/04_img2col/README.md
example/ck_tile/04_img2col/README.md
+12
-0
example/ck_tile/04_img2col/image_to_column.cpp
example/ck_tile/04_img2col/image_to_column.cpp
+170
-0
example/ck_tile/04_img2col/image_to_column.hpp
example/ck_tile/04_img2col/image_to_column.hpp
+105
-0
example/ck_tile/CMakeLists.txt
example/ck_tile/CMakeLists.txt
+1
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
...or_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+2
-2
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
+9
-9
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+6
-6
No files found.
Jenkinsfile
View file @
e3d444c8
...
@@ -320,7 +320,7 @@ def cmake_build(Map conf=[:]){
...
@@ -320,7 +320,7 @@ def cmake_build(Map conf=[:]){
if
(
package_build
==
true
&&
(
env
.
BRANCH_NAME
==
"develop"
||
env
.
BRANCH_NAME
==
"amd-master"
))
{
if
(
package_build
==
true
&&
(
env
.
BRANCH_NAME
==
"develop"
||
env
.
BRANCH_NAME
==
"amd-master"
))
{
archiveArtifacts
artifacts:
"build/*.deb"
,
allowEmptyArchive:
true
,
fingerprint:
true
archiveArtifacts
artifacts:
"build/*.deb"
,
allowEmptyArchive:
true
,
fingerprint:
true
}
}
if
(
params
.
RUN_CK_TILE_TESTS
){
if
(
params
.
RUN_CK_TILE_
FMHA_
TESTS
){
try
{
try
{
archiveArtifacts
"perf_fmha_fwd_*.log"
archiveArtifacts
"perf_fmha_fwd_*.log"
archiveArtifacts
"perf_fmha_bwd_*.log"
archiveArtifacts
"perf_fmha_bwd_*.log"
...
@@ -371,7 +371,7 @@ def buildHipClangJob(Map conf=[:]){
...
@@ -371,7 +371,7 @@ def buildHipClangJob(Map conf=[:]){
def
retimage
def
retimage
(
retimage
,
image
)
=
getDockerImage
(
conf
)
(
retimage
,
image
)
=
getDockerImage
(
conf
)
gitStatusWrapper
(
credentialsId:
"${
status_wrapper
_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCm'
,
repo:
'composable_kernel'
)
{
gitStatusWrapper
(
credentialsId:
"${
env.ck_git
_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCm'
,
repo:
'composable_kernel'
)
{
withDockerContainer
(
image:
image
,
args:
dockerOpts
+
' -v=/var/jenkins/:/var/jenkins'
)
{
withDockerContainer
(
image:
image
,
args:
dockerOpts
+
' -v=/var/jenkins/:/var/jenkins'
)
{
timeout
(
time:
48
,
unit:
'HOURS'
)
timeout
(
time:
48
,
unit:
'HOURS'
)
{
{
...
@@ -426,7 +426,7 @@ def runCKProfiler(Map conf=[:]){
...
@@ -426,7 +426,7 @@ def runCKProfiler(Map conf=[:]){
def
variant
=
env
.
STAGE_NAME
def
variant
=
env
.
STAGE_NAME
def
retimage
def
retimage
gitStatusWrapper
(
credentialsId:
"${
status_wrapper
_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCm'
,
repo:
'composable_kernel'
)
{
gitStatusWrapper
(
credentialsId:
"${
env.ck_git
_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCm'
,
repo:
'composable_kernel'
)
{
try
{
try
{
(
retimage
,
image
)
=
getDockerImage
(
conf
)
(
retimage
,
image
)
=
getDockerImage
(
conf
)
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
...
@@ -563,7 +563,7 @@ def Build_CK(Map conf=[:]){
...
@@ -563,7 +563,7 @@ def Build_CK(Map conf=[:]){
def
variant
=
env
.
STAGE_NAME
def
variant
=
env
.
STAGE_NAME
def
retimage
def
retimage
gitStatusWrapper
(
credentialsId:
"${env.
status_wrapper
_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCm'
,
repo:
'composable_kernel'
)
{
gitStatusWrapper
(
credentialsId:
"${env.
ck_git
_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCm'
,
repo:
'composable_kernel'
)
{
try
{
try
{
(
retimage
,
image
)
=
getDockerImage
(
conf
)
(
retimage
,
image
)
=
getDockerImage
(
conf
)
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
withDockerContainer
(
image:
image
,
args:
dockerOpts
)
{
...
@@ -668,7 +668,7 @@ def process_results(Map conf=[:]){
...
@@ -668,7 +668,7 @@ def process_results(Map conf=[:]){
def
variant
=
env
.
STAGE_NAME
def
variant
=
env
.
STAGE_NAME
def
retimage
def
retimage
gitStatusWrapper
(
credentialsId:
"${env.
status_wrapper
_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCm'
,
repo:
'composable_kernel'
)
{
gitStatusWrapper
(
credentialsId:
"${env.
ck_git
_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCm'
,
repo:
'composable_kernel'
)
{
try
{
try
{
(
retimage
,
image
)
=
getDockerImage
(
conf
)
(
retimage
,
image
)
=
getDockerImage
(
conf
)
}
}
...
@@ -682,7 +682,7 @@ def process_results(Map conf=[:]){
...
@@ -682,7 +682,7 @@ def process_results(Map conf=[:]){
timeout
(
time:
1
,
unit:
'HOURS'
){
timeout
(
time:
1
,
unit:
'HOURS'
){
try
{
try
{
dir
(
"script"
){
dir
(
"script"
){
if
(
params
.
RUN_CK_TILE_TESTS
){
if
(
params
.
RUN_CK_TILE_
FMHA_
TESTS
){
try
{
try
{
unstash
"perf_fmha_fwd_gfx942.log"
unstash
"perf_fmha_fwd_gfx942.log"
unstash
"perf_fmha_bwd_gfx942.log"
unstash
"perf_fmha_bwd_gfx942.log"
...
@@ -838,7 +838,7 @@ pipeline {
...
@@ -838,7 +838,7 @@ pipeline {
dbsshport
=
"${dbsshport}"
dbsshport
=
"${dbsshport}"
dbsshuser
=
"${dbsshuser}"
dbsshuser
=
"${dbsshuser}"
dbsshpassword
=
"${dbsshpassword}"
dbsshpassword
=
"${dbsshpassword}"
status_wrapper_creds
=
"${status_wrapper
_creds}"
ck_git_creds
=
"${ck_git
_creds}"
gerrit_cred
=
"${gerrit_cred}"
gerrit_cred
=
"${gerrit_cred}"
DOCKER_BUILDKIT
=
"1"
DOCKER_BUILDKIT
=
"1"
}
}
...
...
docs/sphinx/requirements.in
View file @
e3d444c8
rocm-docs-core==1.8.
1
rocm-docs-core==1.8.
2
sphinxcontrib-bibtex==2.6.3
sphinxcontrib-bibtex==2.6.3
docs/sphinx/requirements.txt
View file @
e3d444c8
...
@@ -103,7 +103,7 @@ requests==2.32.3
...
@@ -103,7 +103,7 @@ requests==2.32.3
# via
# via
# pygithub
# pygithub
# sphinx
# sphinx
rocm-docs-core==1.8.
1
rocm-docs-core==1.8.
2
# via -r requirements.in
# via -r requirements.in
six==1.16.0
six==1.16.0
# via pybtex
# via pybtex
...
...
example/66_complex_contraction_bilinear/CMakeLists.txt
0 → 100755
View file @
e3d444c8
add_example_executable
(
example_complex_contraction_bilinear_xdl_fp32 complex_contraction_bilinear_xdl_fp32.cpp
)
add_example_executable
(
example_complex_contraction_bilinear_xdl_fp64 complex_contraction_bilinear_xdl_fp64.cpp
)
example/66_complex_contraction_bilinear/README.md
0 → 100755
View file @
e3d444c8
# Instructions for ```example_complex_contraction_bilinear_xdl_fp32```
## Run
```
bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
./bin/example_contraction_bilinear_xdl_fp32 1 1 1
```
example/66_complex_contraction_bilinear/common_instances.hpp
0 → 100644
View file @
e3d444c8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
F64
=
double
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
// Generic instances for fp32, fp16 and bf16 data types.
template
<
ck
::
index_t
NumDimM
,
ck
::
index_t
NumDimN
,
ck
::
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementOp
,
typename
BElementOp
,
typename
CDEElementOp
>
// clang-format off
using
DeviceOpInstanceKK_Generic
=
ck
::
tensor_operation
::
device
::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
;
// clang-format on
template
<
ck
::
index_t
NumDimM
,
ck
::
index_t
NumDimN
,
ck
::
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementOp
,
typename
BElementOp
,
typename
CDEElementOp
>
// clang-format off
using
DeviceOpInstanceKN_Generic
=
ck
::
tensor_operation
::
device
::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
1
,
256
,
256
,
128
,
16
,
4
,
1
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
;
// clang-format on
template
<
ck
::
index_t
NumDimM
,
ck
::
index_t
NumDimN
,
ck
::
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementOp
,
typename
BElementOp
,
typename
CDEElementOp
>
// clang-format off
using
DeviceOpInstanceMK_Generic
=
ck
::
tensor_operation
::
device
::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
1
,
256
,
256
,
128
,
16
,
1
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
;
// clang-format on
template
<
ck
::
index_t
NumDimM
,
ck
::
index_t
NumDimN
,
ck
::
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementOp
,
typename
BElementOp
,
typename
CDEElementOp
>
// clang-format off
using
DeviceOpInstanceMN_Generic
=
ck
::
tensor_operation
::
device
::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
1
,
256
,
256
,
128
,
16
,
1
,
1
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
ComputeDataType
>
;
// clang-format on
// Fp64 instances.
template
<
ck
::
index_t
NumDimM
,
ck
::
index_t
NumDimN
,
ck
::
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementOp
,
typename
BElementOp
,
typename
CDEElementOp
>
// clang-format off
using
DeviceOpInstanceKK_FP64
=
ck
::
tensor_operation
::
device
::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
1
,
256
,
128
,
128
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
,
ComputeDataType
>
;
// clang-format on
template
<
ck
::
index_t
NumDimM
,
ck
::
index_t
NumDimN
,
ck
::
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementOp
,
typename
BElementOp
,
typename
CDEElementOp
>
// clang-format off
using
DeviceOpInstanceKN_FP64
=
ck
::
tensor_operation
::
device
::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
1
,
256
,
128
,
128
,
16
,
2
,
1
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
,
ComputeDataType
>
;
// clang-format on
template
<
ck
::
index_t
NumDimM
,
ck
::
index_t
NumDimN
,
ck
::
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementOp
,
typename
BElementOp
,
typename
CDEElementOp
>
// clang-format off
using
DeviceOpInstanceMK_FP64
=
ck
::
tensor_operation
::
device
::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
1
,
256
,
128
,
128
,
16
,
1
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
,
ComputeDataType
>
;
// clang-format on
template
<
ck
::
index_t
NumDimM
,
ck
::
index_t
NumDimN
,
ck
::
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementOp
,
typename
BElementOp
,
typename
CDEElementOp
>
// clang-format off
using
DeviceOpInstanceMN_FP64
=
ck
::
tensor_operation
::
device
::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
1
,
256
,
128
,
128
,
16
,
1
,
1
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
,
ComputeDataType
>
;
// clang-format on
example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp
0 → 100755
View file @
e3d444c8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "common_instances.hpp"
using
ADataType
=
F32
;
using
BDataType
=
F32
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
DDataType
=
F32
;
using
DsDataType
=
ck
::
Tuple
<
DDataType
>
;
using
EDataType
=
F32
;
using
ComputeDataType
=
F32
;
static
constexpr
ck
::
index_t
NumDimM
=
2
;
static
constexpr
ck
::
index_t
NumDimN
=
2
;
static
constexpr
ck
::
index_t
NumDimK
=
2
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CDEElementOp
=
ck
::
tensor_operation
::
element_wise
::
Bilinear
;
using
DeviceOpInstanceKKNN
=
DeviceOpInstanceKK_Generic
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
using
DeviceOpInstanceKNNN
=
DeviceOpInstanceKN_Generic
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
using
DeviceOpInstanceMKNN
=
DeviceOpInstanceMK_Generic
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
using
DeviceOpInstanceMNNN
=
DeviceOpInstanceMN_Generic
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
using
DeviceOpInstance
=
DeviceOpInstanceKKNN
;
#include "run_complex_contraction_bilinear_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_complex_contraction_bilinear_example
(
argc
,
argv
);
}
example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp
0 → 100755
View file @
e3d444c8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "common_instances.hpp"
using
ADataType
=
F64
;
using
BDataType
=
F64
;
using
AccDataType
=
F64
;
using
CShuffleDataType
=
F64
;
using
DDataType
=
F64
;
using
DsDataType
=
ck
::
Tuple
<
DDataType
>
;
using
EDataType
=
F64
;
using
ComputeDataType
=
F64
;
static
constexpr
ck
::
index_t
NumDimM
=
2
;
static
constexpr
ck
::
index_t
NumDimN
=
2
;
static
constexpr
ck
::
index_t
NumDimK
=
2
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CDEElementOp
=
ck
::
tensor_operation
::
element_wise
::
Bilinear
;
using
DeviceOpInstanceKKNN
=
DeviceOpInstanceKK_FP64
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
using
DeviceOpInstanceKNNN
=
DeviceOpInstanceKN_FP64
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
using
DeviceOpInstanceMKNN
=
DeviceOpInstanceMK_FP64
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
using
DeviceOpInstanceMNNN
=
DeviceOpInstanceMN_FP64
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
using
DeviceOpInstance
=
DeviceOpInstanceKKNN
;
#include "run_complex_contraction_bilinear_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_complex_contraction_bilinear_example
(
argc
,
argv
);
}
example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc
0 → 100755
View file @
e3d444c8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
int
run_complex_contraction_bilinear_example
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
// A[M0, M1, K0, K1]
std
::
vector
<
ck
::
index_t
>
a_ms_ks_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
a_ms_ks_strides
{
524288
,
4096
,
128
,
1
};
// B[N0, N1, K0, K1]
std
::
vector
<
ck
::
index_t
>
b_ns_ks_lengths
{
32
,
64
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
b_ns_ks_strides
{
524288
,
4096
,
128
,
1
};
// D[M0, M1, N0, N1]
std
::
vector
<
ck
::
index_t
>
d_ms_ns_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
d_ms_ns_strides
{
524288
,
4096
,
128
,
1
};
// E[M0, M1, N0, N1]
std
::
vector
<
ck
::
index_t
>
e_ms_ns_lengths
{
30
,
128
,
32
,
64
};
std
::
vector
<
ck
::
index_t
>
e_ms_ns_strides
{
524288
,
4096
,
128
,
1
};
float
alpha
=
1.
f
;
float
beta
=
1.
f
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
28
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
const
ck
::
index_t
M0
=
std
::
stoi
(
argv
[
4
]);
const
ck
::
index_t
M1
=
std
::
stoi
(
argv
[
5
]);
const
ck
::
index_t
N0
=
std
::
stoi
(
argv
[
6
]);
const
ck
::
index_t
N1
=
std
::
stoi
(
argv
[
7
]);
const
ck
::
index_t
K0
=
std
::
stoi
(
argv
[
8
]);
const
ck
::
index_t
K1
=
std
::
stoi
(
argv
[
9
]);
a_ms_ks_lengths
=
{
M0
,
M1
,
K0
,
K1
};
a_ms_ks_strides
=
{
std
::
stoi
(
argv
[
10
]),
std
::
stoi
(
argv
[
11
]),
std
::
stoi
(
argv
[
12
]),
std
::
stoi
(
argv
[
13
])};
b_ns_ks_lengths
=
{
N0
,
N1
,
K0
,
K1
};
b_ns_ks_strides
=
{
std
::
stoi
(
argv
[
14
]),
std
::
stoi
(
argv
[
15
]),
std
::
stoi
(
argv
[
16
]),
std
::
stoi
(
argv
[
17
])};
d_ms_ns_lengths
=
{
M0
,
M1
,
N0
,
N1
};
d_ms_ns_strides
=
{
std
::
stoi
(
argv
[
18
]),
std
::
stoi
(
argv
[
19
]),
std
::
stoi
(
argv
[
20
]),
std
::
stoi
(
argv
[
21
])};
e_ms_ns_lengths
=
{
M0
,
M1
,
N0
,
N1
};
e_ms_ns_strides
=
{
std
::
stoi
(
argv
[
22
]),
std
::
stoi
(
argv
[
23
]),
std
::
stoi
(
argv
[
24
]),
std
::
stoi
(
argv
[
25
])};
alpha
=
std
::
stof
(
argv
[
26
]);
beta
=
std
::
stof
(
argv
[
27
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 9: M0, M1, N0, N1, K0, K1
\n
"
);
printf
(
"arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1
\n
"
);
printf
(
"arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1
\n
"
);
printf
(
"arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1
\n
"
);
printf
(
"arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1
\n
"
);
printf
(
"arg26 to 27: alpha, beta
\n
"
);
exit
(
0
);
}
// For Real Part of Complex Tensor
Tensor
<
ADataType
>
a_ms_ks_re
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
Tensor
<
BDataType
>
b_ns_ks_re
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
Tensor
<
EDataType
>
d_ms_ns_re
(
d_ms_ns_lengths
,
d_ms_ns_strides
);
Tensor
<
EDataType
>
e_ms_ns_host_result_re
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
EDataType
>
e_ms_ns_device_result_re
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
// For Imaginary Part of Complex Tensor
Tensor
<
ADataType
>
a_ms_ks_img
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
Tensor
<
BDataType
>
b_ns_ks_img
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
Tensor
<
EDataType
>
d_ms_ns_img
(
d_ms_ns_lengths
,
d_ms_ns_strides
);
Tensor
<
EDataType
>
e_ms_ns_host_result_img
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
EDataType
>
e_ms_ns_device_result_img
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
// Intermediate E tensor Definition
Tensor
<
EDataType
>
e_ms_ns_device_result_re1
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
EDataType
>
e_ms_ns_device_result_img1
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
std
::
cout
<<
"a_ms_ks_re: "
<<
a_ms_ks_re
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_ns_ks_re: "
<<
b_ns_ks_re
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_ms_ns_re: "
<<
d_ms_ns_re
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_ms_ns_re: "
<<
e_ms_ns_host_result_re
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_ms_ks_img: "
<<
a_ms_ks_img
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_ns_ks_img: "
<<
b_ns_ks_img
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_ms_ns_img: "
<<
d_ms_ns_img
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_ms_ns_img: "
<<
e_ms_ns_host_result_img
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_ms_ks_re
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_ns_ks_re
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
d_ms_ns_re
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
a_ms_ks_img
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_ns_ks_img
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
d_ms_ns_img
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
break
;
default
:
a_ms_ks_re
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_ns_ks_re
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
d_ms_ns_re
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
a_ms_ks_img
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_ns_ks_img
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
d_ms_ns_img
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
}
DeviceMem
a_device_buf_re
(
sizeof
(
ADataType
)
*
a_ms_ks_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf_re
(
sizeof
(
BDataType
)
*
b_ns_ks_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf_re
(
sizeof
(
DDataType
)
*
d_ms_ns_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_re
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_device_buf_img
(
sizeof
(
ADataType
)
*
a_ms_ks_img
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf_img
(
sizeof
(
BDataType
)
*
b_ns_ks_img
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf_img
(
sizeof
(
DDataType
)
*
d_ms_ns_img
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_img
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_img
.
mDesc
.
GetElementSpaceSize
());
// Intermediate Value For E Real and Img
DeviceMem
e_device_buf_re1
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_img1
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_img
.
mDesc
.
GetElementSpaceSize
());
a_device_buf_re
.
ToDevice
(
a_ms_ks_re
.
mData
.
data
());
b_device_buf_re
.
ToDevice
(
b_ns_ks_re
.
mData
.
data
());
d_device_buf_re
.
ToDevice
(
d_ms_ns_re
.
mData
.
data
());
a_device_buf_img
.
ToDevice
(
a_ms_ks_img
.
mData
.
data
());
b_device_buf_img
.
ToDevice
(
b_ns_ks_img
.
mData
.
data
());
d_device_buf_img
.
ToDevice
(
d_ms_ns_img
.
mData
.
data
());
// set zero
e_device_buf_re
.
SetZero
();
e_device_buf_img
.
SetZero
();
// set zero for intermediate values
e_device_buf_re1
.
SetZero
();
e_device_buf_img1
.
SetZero
();
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
// device operation
// For real Intermediate Value re_1
auto
op
=
DeviceOpInstance
{};
auto
invoker
=
op
.
MakeInvoker
();
auto
argument_re1
=
op
.
MakeArgument
(
a_device_buf_re
.
GetDeviceBuffer
(),
b_device_buf_re
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_device_buf_re
.
GetDeviceBuffer
()},
e_device_buf_re1
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_strides
},
e_ms_ns_lengths
,
e_ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
op
.
IsSupportedArgument
(
argument_re1
))
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
float
ave_time_re1
=
invoker
.
Run
(
argument_re1
,
StreamConfig
{
nullptr
,
time_kernel
});
alpha
=
-
1.
f
;
beta
=
1.
f
;
a_element_op
=
AElementOp
{};
b_element_op
=
BElementOp
{};
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
// device operation
// For real Intermediate Value re_2
// auto op = DeviceOpInstance{};
// auto invoker = op.MakeInvoker();
auto
argument_re2
=
op
.
MakeArgument
(
a_device_buf_img
.
GetDeviceBuffer
(),
b_device_buf_img
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
e_device_buf_re1
.
GetDeviceBuffer
()},
e_device_buf_re
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_strides
},
e_ms_ns_lengths
,
e_ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
op
.
IsSupportedArgument
(
argument_re2
))
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
float
ave_time_re2
=
invoker
.
Run
(
argument_re2
,
StreamConfig
{
nullptr
,
time_kernel
});
alpha
=
1.
f
;
beta
=
1.
f
;
a_element_op
=
AElementOp
{};
b_element_op
=
BElementOp
{};
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
auto
argument_img1
=
op
.
MakeArgument
(
a_device_buf_re
.
GetDeviceBuffer
(),
b_device_buf_img
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_device_buf_img
.
GetDeviceBuffer
()},
e_device_buf_img1
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_strides
},
e_ms_ns_lengths
,
e_ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
op
.
IsSupportedArgument
(
argument_img1
))
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
float
ave_time_img1
=
invoker
.
Run
(
argument_img1
,
StreamConfig
{
nullptr
,
time_kernel
});
alpha
=
1.
f
;
beta
=
1.
f
;
auto
argument_img2
=
op
.
MakeArgument
(
a_device_buf_img
.
GetDeviceBuffer
(),
b_device_buf_re
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
e_device_buf_img1
.
GetDeviceBuffer
()},
e_device_buf_img
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_ms_ns_strides
},
e_ms_ns_lengths
,
e_ms_ns_strides
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
op
.
IsSupportedArgument
(
argument_img2
))
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
}
float
ave_time_img2
=
invoker
.
Run
(
argument_img2
,
StreamConfig
{
nullptr
,
time_kernel
});
ck
::
index_t
M
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
e_ms_ns_lengths
.
begin
(),
NumDimM
,
1
,
std
::
multiplies
<>
{});
ck
::
index_t
N
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
e_ms_ns_lengths
.
begin
()
+
NumDimM
,
NumDimN
,
1
,
std
::
multiplies
<>
{});
ck
::
index_t
K
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
a_ms_ks_lengths
.
begin
()
+
NumDimM
,
NumDimK
,
1
,
std
::
multiplies
<>
{});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
*
2
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
DDataType
)
*
M
*
N
+
sizeof
(
EDataType
)
*
M
*
N
*
2
;
float
ave_time
=
ave_time_img2
+
ave_time_img1
+
ave_time_re2
+
ave_time_re1
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op
.
GetTypeString
()
<<
std
::
endl
;
e_device_buf_re
.
FromDevice
(
e_ms_ns_device_result_re
.
mData
.
data
());
e_device_buf_img
.
FromDevice
(
e_ms_ns_device_result_img
.
mData
.
data
());
auto
isRealOk
=
0
;
auto
isImgOk
=
0
;
if
(
do_verification
)
{
// Real Part Verification
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result_re
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result_re1
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
using
ReferenceOpInstance
=
ck
::
tensor_operation
::
host
::
ReferenceContraction_M2_N2_K2
<
NumDimM
,
NumDimN
,
NumDimK
,
ADataType
,
BDataType
,
CShuffleDataType
,
AccDataType
,
F32
,
AElementOp
,
BElementOp
>
;
auto
ref_op
=
ReferenceOpInstance
{};
auto
ref_invoker
=
ref_op
.
MakeInvoker
();
auto
ref_argument_re
=
ref_op
.
MakeArgument
(
a_ms_ks_re
,
b_ns_ks_re
,
c_ms_ns_host_result_re
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument_re
);
alpha
=
1.
f
;
beta
=
1.
f
;
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op
(
e_ms_ns_host_result_re
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result_re
(
m0
,
m1
,
n0
,
n1
),
d_ms_ns_re
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
alpha
=
1.
f
;
beta
=
-
1.
f
;
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
auto
ref_argument_re1
=
ref_op
.
MakeArgument
(
a_ms_ks_img
,
b_ns_ks_img
,
c_ms_ns_host_result_re1
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument_re1
);
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op
(
e_ms_ns_host_result_re
(
m0
,
m1
,
n0
,
n1
),
e_ms_ns_host_result_re
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result_re1
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
isRealOk
=
ck
::
utils
::
check_err
(
e_ms_ns_device_result_re
,
e_ms_ns_host_result_re
)
?
0
:
1
;
// Img Part Verification
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result_img
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result_img1
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
auto
ref_argument_img
=
ref_op
.
MakeArgument
(
a_ms_ks_re
,
b_ns_ks_img
,
c_ms_ns_host_result_img
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument_img
);
alpha
=
1.
f
;
beta
=
1.
f
;
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op
(
e_ms_ns_host_result_img
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result_img
(
m0
,
m1
,
n0
,
n1
),
d_ms_ns_img
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
auto
ref_argument_img1
=
ref_op
.
MakeArgument
(
a_ms_ks_img
,
b_ns_ks_re
,
c_ms_ns_host_result_img1
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument_img1
);
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result_img
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op
(
e_ms_ns_host_result_img
(
m0
,
m1
,
n0
,
n1
),
e_ms_ns_host_result_img
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result_img1
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
isImgOk
=
ck
::
utils
::
check_err
(
e_ms_ns_device_result_re
,
e_ms_ns_host_result_re
)
?
0
:
1
;
return
(
isRealOk
&&
isImgOk
);
}
return
0
;
}
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
e3d444c8
...
@@ -99,13 +99,26 @@ auto create_args(int argc, char* argv[])
...
@@ -99,13 +99,26 @@ auto create_args(int argc, char* argv[])
// different threshold for different dtype
// different threshold for different dtype
template
<
typename
DataType
>
template
<
typename
DataType
>
auto
get_elimit
(
int
/*init_method
*/
)
auto
get_elimit
(
ck_tile
::
index_t
/*hdim_q*/
,
ck_tile
::
index_t
/*hdim_v
*/
)
{
{
double
rtol
=
1e-2
;
double
rtol
=
1e-2
;
double
atol
=
1e-2
;
double
atol
=
1e-2
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
}
template
<
>
auto
get_elimit
<
ck_tile
::
bf16_t
>
(
ck_tile
::
index_t
hdim_q
,
ck_tile
::
index_t
hdim_v
)
{
double
rtol
=
1e-2
;
double
atol
=
1e-2
;
if
(
hdim_q
>
128
&&
hdim_v
>
128
)
// 3.2 for RTZ/1.5 for RTN
{
rtol
=
3.2e-2
;
atol
=
3.2e-2
;
}
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
template
<
typename
DataType
>
template
<
typename
DataType
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
{
...
@@ -899,7 +912,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -899,7 +912,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
// clang-format on
// clang-format on
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
(
init_method
);
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
(
hdim_q
,
hdim_v
);
bool
dq_cur_pass
=
ck_tile
::
check_err
(
dq_host_result
,
bool
dq_cur_pass
=
ck_tile
::
check_err
(
dq_host_result
,
dq_host_ref
,
dq_host_ref
,
std
::
string
(
"Error: QGrad Incorrect results!"
),
std
::
string
(
"Error: QGrad Incorrect results!"
),
...
...
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
e3d444c8
...
@@ -552,16 +552,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -552,16 +552,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
#endif
#endif
auto
get_lengths
=
[
&
](
bool
permute
,
struct
ck_tile
::
index_t
b
/*batch*/
,
{
ck_tile
::
index_t
h
/*nhead*/
,
auto
operator
()(
bool
permute
,
ck_tile
::
index_t
s
/*seqlen*/
,
ck_tile
::
index_t
b
/*batch*/
,
ck_tile
::
index_t
d
/*hdim*/
)
{
ck_tile
::
index_t
h
/*nhead*/
,
if
(
permute
)
ck_tile
::
index_t
s
/*seqlen*/
,
return
std
::
array
<
ck_tile
::
index_t
,
4
>
{
b
,
h
,
s
,
d
};
ck_tile
::
index_t
d
/*hdim*/
)
else
{
return
std
::
array
<
ck_tile
::
index_t
,
4
>
{
b
,
s
,
h
,
d
};
if
(
permute
)
};
return
std
::
array
<
ck_tile
::
index_t
,
4
>
{
b
,
h
,
s
,
d
};
else
return
std
::
array
<
ck_tile
::
index_t
,
4
>
{
b
,
s
,
h
,
d
};
}
auto
operator
()(
bool
permute
,
ck_tile
::
index_t
ns
/*num_splits*/
,
ck_tile
::
index_t
b
/*batch*/
,
ck_tile
::
index_t
h
/*nhead*/
,
ck_tile
::
index_t
s
/*seqlen*/
,
ck_tile
::
index_t
d
/*hdim*/
)
{
if
(
permute
)
return
std
::
array
<
ck_tile
::
index_t
,
5
>
{
ns
,
b
,
h
,
s
,
d
};
else
return
std
::
array
<
ck_tile
::
index_t
,
5
>
{
ns
,
b
,
s
,
h
,
d
};
}
}
get_lengths
;
bool
is_v_rowmajor
=
vlayout
==
std
::
string
(
"r"
);
bool
is_v_rowmajor
=
vlayout
==
std
::
string
(
"r"
);
...
@@ -617,7 +634,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -617,7 +634,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
ck_tile
::
HostTensor
<
OaccDataType
>
o_acc_host
(
ck_tile
::
HostTensor
<
OaccDataType
>
o_acc_host
(
1
<
num_splits
||
use_kvcache
1
<
num_splits
||
use_kvcache
?
std
::
array
<
ck_tile
::
index_t
,
5
>
{
num_splits
,
batch
,
nhead
,
max
_seqlen_q
,
hdim_v
}
?
get_lengths
(
o_perm
,
num_splits
,
shape_
batch
,
nhead
,
shape
_seqlen_q
,
hdim_v
)
:
std
::
array
<
ck_tile
::
index_t
,
5
>
{
1
,
1
,
1
,
1
,
1
});
:
std
::
array
<
ck_tile
::
index_t
,
5
>
{
1
,
1
,
1
,
1
,
1
});
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// batch mode of lse data layout is [batch, nhead, seqlen_q]
...
@@ -854,7 +871,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -854,7 +871,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}();
}();
const
ck_tile
::
index_t
stride_bias
=
(
i_perm
?
shape_seqlen_k
:
1
*
shape_seqlen_k
);
const
ck_tile
::
index_t
stride_bias
=
(
i_perm
?
shape_seqlen_k
:
1
*
shape_seqlen_k
);
const
ck_tile
::
index_t
stride_randval
=
(
max_seqlen_k
);
const
ck_tile
::
index_t
stride_randval
=
(
max_seqlen_k
);
const
ck_tile
::
index_t
stride_o_acc
=
hdim_v
;
const
ck_tile
::
index_t
stride_o_acc
=
(
o_perm
?
hdim_v
:
nhead
*
hdim_v
)
;
const
ck_tile
::
index_t
stride_o
=
(
o_perm
?
hdim_v
:
nhead
*
hdim_v
);
const
ck_tile
::
index_t
stride_o
=
(
o_perm
?
hdim_v
:
nhead
*
hdim_v
);
// setup nhead_stride_* arguments
// setup nhead_stride_* arguments
const
ck_tile
::
index_t
nhead_stride_q
=
(
i_perm
?
shape_seqlen_q
*
hdim_q
:
hdim_q
);
const
ck_tile
::
index_t
nhead_stride_q
=
(
i_perm
?
shape_seqlen_q
*
hdim_q
:
hdim_q
);
...
@@ -881,7 +898,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -881,7 +898,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
nhead_stride_randval
=
(
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_randval
=
(
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
nhead_stride_lse
=
shape_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_lse
=
shape_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_lse_acc
=
shape_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_lse_acc
=
shape_seqlen_q
;
const
ck_tile
::
index_t
nhead_stride_o_acc
=
(
max
_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_o_acc
=
(
o_perm
?
shape
_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_o
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_o
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
// setup batch_stride_* arguments
// setup batch_stride_* arguments
const
ck_tile
::
index_t
batch_stride_q
=
(
nhead
*
shape_seqlen_q
*
hdim_q
);
const
ck_tile
::
index_t
batch_stride_q
=
(
nhead
*
shape_seqlen_q
*
hdim_q
);
...
@@ -897,12 +914,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -897,12 +914,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
batch_stride_randval
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_randval
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_lse
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_lse
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_lse_acc
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_lse_acc
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_o_acc
=
(
nhead
*
max
_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_o_acc
=
(
nhead
*
shape
_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_o
=
(
nhead
*
shape_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_o
=
(
nhead
*
shape_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_block_table
=
(
max_num_page_blocks
/
batch
);
const
ck_tile
::
index_t
batch_stride_block_table
=
(
max_num_page_blocks
/
batch
);
// setup split_stride_* arguments (only used in split-kv kernel)
// setup split_stride_* arguments (only used in split-kv kernel)
const
ck_tile
::
index_t
split_stride_lse_acc
=
(
shape_batch
*
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
split_stride_lse_acc
=
(
shape_batch
*
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
split_stride_o_acc
=
(
batch
*
nhead
*
max
_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
split_stride_o_acc
=
(
shape_
batch
*
nhead
*
shape
_seqlen_q
*
hdim_v
);
args
.
q_ptr
=
q_buf
.
GetDeviceBuffer
();
args
.
q_ptr
=
q_buf
.
GetDeviceBuffer
();
args
.
k_ptr
=
k_buf
.
GetDeviceBuffer
();
args
.
k_ptr
=
k_buf
.
GetDeviceBuffer
();
...
...
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
e3d444c8
...
@@ -398,10 +398,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
...
@@ -398,10 +398,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args
.
nhead_stride_bias
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_lse_acc
,
args
.
nhead_stride_lse_acc
,
args
.
nhead_stride_o_acc
,
args
.
nhead_stride_o_acc
,
args
.
batch_stride_k
,
args
.
batch_stride_k
,
// only used for paged-kvcache
args
.
batch_stride_v
,
args
.
batch_stride_v
,
// only used for paged-kvcache
args
.
batch_stride_lse_acc
,
args
.
batch_stride_o_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_o_acc
,
args
.
split_stride_o_acc
,
args
.
window_size_left
,
args
.
window_size_left
,
...
@@ -475,7 +473,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
...
@@ -475,7 +473,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args
.
lse_ptr
,
args
.
lse_ptr
,
args
.
o_ptr
,
args
.
o_ptr
,
args
.
batch
,
args
.
batch
,
args
.
max_seqlen_q
,
args
.
seqstart_q_ptr
,
args
.
seqstart_q_ptr
,
args
.
hdim_v
,
args
.
hdim_v
,
args
.
num_splits
,
args
.
num_splits
,
...
@@ -486,7 +483,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
...
@@ -486,7 +483,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args
.
nhead_stride_o_acc
,
args
.
nhead_stride_o_acc
,
args
.
nhead_stride_lse
,
args
.
nhead_stride_lse
,
args
.
nhead_stride_o
,
args
.
nhead_stride_o
,
args
.
batch_stride_o_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_o_acc
);
args
.
split_stride_o_acc
);
}
}
...
@@ -497,7 +493,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
...
@@ -497,7 +493,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args
.
lse_ptr
,
args
.
lse_ptr
,
args
.
o_ptr
,
args
.
o_ptr
,
args
.
batch
,
args
.
batch
,
args
.
max_seqlen_q
,
args
.
seqlen_q
,
args
.
seqlen_q
,
args
.
hdim_v
,
args
.
hdim_v
,
args
.
num_splits
,
args
.
num_splits
,
...
...
example/ck_tile/04_img2col/CMakeLists.txt
0 → 100644
View file @
e3d444c8
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable
(
tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp
)
example/ck_tile/04_img2col/README.md
0 → 100644
View file @
e3d444c8
# Image to Column
This folder contains example for Image to Column using ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_img2col -j
```
This will result in an executable
`build/bin/tile_example_img2col`
example/ck_tile/04_img2col/image_to_column.cpp
0 → 100644
View file @
e3d444c8
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstring>
#include "ck_tile/host.hpp"
#include "image_to_column.hpp"
// Host API implementation
template
<
>
float
image_to_column
(
const
image_to_column_traits
&
traits
,
const
image_to_column_args
<
2
>&
args
,
const
ck_tile
::
stream_config
&
stream_conf
)
{
if
(
traits
.
data_type
.
compare
(
"fp16"
)
==
0
)
{
constexpr
ck_tile
::
index_t
NDimSpatial
=
2
;
constexpr
ck_tile
::
index_t
VectorSize
=
8
;
using
thread_tile
=
ck_tile
::
sequence
<
8
,
8
>
;
using
warp_tile
=
ck_tile
::
sequence
<
64
,
64
>
;
using
block_tile
=
ck_tile
::
sequence
<
128
,
128
>
;
using
Shape
=
ck_tile
::
TileImageToColumnShape
<
thread_tile
,
warp_tile
,
block_tile
>
;
using
InDataType
=
ck_tile
::
half_t
;
using
OutDataType
=
ck_tile
::
half_t
;
using
PipelineProblem
=
ck_tile
::
BlockImageToColumnProblem
<
InDataType
,
OutDataType
,
Shape
,
NDimSpatial
,
VectorSize
,
VectorSize
>
;
using
Kernel
=
ck_tile
::
ImageToColumn
<
PipelineProblem
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_in
,
args
.
p_out
,
args
.
G
,
args
.
N
,
args
.
C
,
args
.
input_spatial_lengths
,
args
.
filter_spatial_lengths
,
args
.
output_spatial_lengths
,
args
.
image_g_n_c_wis_strides
,
args
.
gemm_g_m_k_strides
,
args
.
conv_filter_strides
,
args
.
conv_filter_dilations
,
args
.
input_left_pads
,
args
.
input_right_pads
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
N
*
args
.
output_spatial_lengths
[
0
]
*
args
.
output_spatial_lengths
[
1
],
args
.
filter_spatial_lengths
[
0
]
*
args
.
filter_spatial_lengths
[
1
]
*
args
.
C
,
args
.
G
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
ck_tile
::
index_t
kBlockPerCu
=
2
;
float
ave_time
=
ck_tile
::
launch_kernel
(
stream_conf
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
return
ave_time
;
}
return
0
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
constexpr
ck_tile
::
index_t
NDimSpatial
=
2
;
ExecutionConfig
config
;
ck_tile
::
conv
::
ConvParam
conv_params
=
DefaultConvParams
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_params
))
{
return
EXIT_FAILURE
;
}
if
(
conv_params
.
num_dim_spatial_
!=
NDimSpatial
)
{
std
::
cerr
<<
"unsupported # of spatial dimensions"
<<
std
::
endl
;
return
EXIT_FAILURE
;
}
using
InDataType
=
ck_tile
::
half_t
;
using
OutDataType
=
ck_tile
::
half_t
;
using
ImLayout
=
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
;
const
auto
G
=
conv_params
.
G_
;
const
auto
N
=
conv_params
.
N_
;
const
auto
C
=
conv_params
.
C_
;
const
ck_tile
::
long_index_t
NHoWo
=
N
*
std
::
accumulate
(
conv_params
.
output_spatial_lengths_
.
begin
(),
std
::
next
(
conv_params
.
output_spatial_lengths_
.
begin
(),
NDimSpatial
),
1
,
std
::
multiplies
<>
());
const
ck_tile
::
long_index_t
CYX
=
C
*
std
::
accumulate
(
conv_params
.
filter_spatial_lengths_
.
begin
(),
std
::
next
(
conv_params
.
filter_spatial_lengths_
.
begin
(),
NDimSpatial
),
1
,
std
::
multiplies
<>
());
const
auto
in_desc
=
ck_tile
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
ImLayout
>
(
conv_params
);
const
auto
out_desc
=
ck_tile
::
HostTensorDescriptor
({
G
,
NHoWo
,
CYX
});
// host verify
ck_tile
::
HostTensor
<
InDataType
>
in
(
in_desc
);
ck_tile
::
HostTensor
<
OutDataType
>
out_device
(
out_desc
);
ck_tile
::
HostTensor
<
OutDataType
>
out_host
(
out_desc
);
switch
(
config
.
init_method
)
{
case
0
:
break
;
case
1
:
ck_tile
::
FillUniformDistributionIntegerValue
<
InDataType
>
{
-
5.
f
,
5.
f
}(
in
);
break
;
default:
ck_tile
::
FillUniformDistribution
<
InDataType
>
{
-
0.5
,
0.5
}(
in
);
break
;
}
ck_tile
::
DeviceMem
in_device_buf
(
in
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
out_device_buf
(
out_device
.
get_element_space_size_in_bytes
());
in_device_buf
.
ToDevice
(
in
.
data
());
image_to_column_traits
traits
{
"fp16"
};
image_to_column_args
<
NDimSpatial
>
args
{
in_device_buf
.
GetDeviceBuffer
(),
out_device_buf
.
GetDeviceBuffer
(),
G
,
N
,
C
,
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
input_spatial_lengths_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
filter_spatial_lengths_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
output_spatial_lengths_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
+
3
>
(
in_desc
.
get_strides
()),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
3
>
(
out_desc
.
get_strides
()),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
conv_filter_strides_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
conv_filter_dilations_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
input_left_pads_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
input_right_pads_
)};
float
ave_time
=
image_to_column
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
config
.
time_kernel
});
std
::
size_t
num_btype
=
G
*
NHoWo
*
CYX
*
(
sizeof
(
OutDataType
)
+
sizeof
(
InDataType
));
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
{
// reference
ck_tile
::
reference_im2col
<
InDataType
,
OutDataType
,
NDimSpatial
>
(
in
,
out_host
,
conv_params
);
out_device_buf
.
FromDevice
(
out_device
.
data
());
pass
=
ck_tile
::
check_err
(
out_device
,
out_host
);
std
::
cout
<<
"valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
endl
;
}
return
!
pass
;
}
example/ck_tile/04_img2col/image_to_column.hpp
0 → 100644
View file @
e3d444c8
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/image_to_column.hpp"
#include <string>
#define DefaultConvParams \
ck_tile::conv::ConvParam \
{ \
2, 2, 32, 32, 32, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \
}
struct
ExecutionConfig
final
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
};
inline
void
print_help_msg
()
{
std
::
cerr
<<
"arg1: verification (0=no, 1=yes)
\n
"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
<<
"arg3: time kernel (0=no, 1=yes)
\n
"
<<
ck_tile
::
conv
::
get_conv_param_parser_helper_msg
()
<<
std
::
endl
;
}
inline
bool
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ExecutionConfig
&
config
,
ck_tile
::
conv
::
ConvParam
&
conv_params
)
{
constexpr
int
num_execution_config_args
=
3
;
// arguments for do_verification, init_method, time_kernel
constexpr
int
num_conv_param_leading_args
=
5
;
// arguments for num_dim_spatial_, G_, N_, K_, C_
constexpr
int
threshold_to_catch_partial_args
=
1
+
num_execution_config_args
;
constexpr
int
threshold_to_catch_all_args
=
threshold_to_catch_partial_args
+
num_conv_param_leading_args
;
if
(
argc
==
1
)
{
// use default
config
=
ExecutionConfig
{};
}
// catch only ExecutionConfig arguments
else
if
(
argc
==
threshold_to_catch_partial_args
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
// catch both ExecutionConfig & ConvParam arguments
else
if
(
threshold_to_catch_all_args
<
argc
&&
((
argc
-
threshold_to_catch_all_args
)
%
3
==
0
))
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
const
ck_tile
::
index_t
num_dim_spatial
=
std
::
stoi
(
argv
[
4
]);
conv_params
=
ck_tile
::
conv
::
parse_conv_param
(
num_dim_spatial
,
threshold_to_catch_partial_args
,
argv
);
}
else
{
print_help_msg
();
return
false
;
}
return
true
;
}
struct
image_to_column_traits
{
std
::
string
data_type
;
};
template
<
ck_tile
::
index_t
NDimSpatial
>
struct
image_to_column_args
{
const
void
*
p_in
;
void
*
p_out
;
const
ck_tile
::
long_index_t
G
;
const
ck_tile
::
long_index_t
N
;
const
ck_tile
::
long_index_t
C
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
input_spatial_lengths
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
filter_spatial_lengths
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
output_spatial_lengths
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
+
3
>
image_g_n_c_wis_strides
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
3
>
gemm_g_m_k_strides
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
conv_filter_strides
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
conv_filter_dilations
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
input_left_pads
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
input_right_pads
;
};
// host API
template
<
ck_tile
::
index_t
NDimSpatial
>
float
image_to_column
(
const
image_to_column_traits
&
,
const
image_to_column_args
<
NDimSpatial
>&
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/CMakeLists.txt
View file @
e3d444c8
...
@@ -5,3 +5,4 @@ include_directories(AFTER
...
@@ -5,3 +5,4 @@ include_directories(AFTER
add_subdirectory
(
01_fmha
)
add_subdirectory
(
01_fmha
)
add_subdirectory
(
02_layernorm2d
)
add_subdirectory
(
02_layernorm2d
)
add_subdirectory
(
03_gemm
)
add_subdirectory
(
03_gemm
)
add_subdirectory
(
04_img2col
)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
View file @
e3d444c8
...
@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
...
@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}
}
template
<
>
template
<
>
__device__
static
constexpr
auto
TailScheduler
<
1
>
()
__device__
constexpr
auto
TailScheduler
<
1
>
()
{
{
// schedule
// schedule
constexpr
auto
num_ds_read_inst
=
constexpr
auto
num_ds_read_inst
=
...
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
...
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}
}
template
<
>
template
<
>
__device__
static
constexpr
auto
TailScheduler
<
2
>
()
__device__
constexpr
auto
TailScheduler
<
2
>
()
{
{
// schedule
// schedule
constexpr
auto
num_ds_read_inst
=
constexpr
auto
num_ds_read_inst
=
...
...
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
View file @
e3d444c8
...
@@ -324,55 +324,55 @@ struct DppSelector
...
@@ -324,55 +324,55 @@ struct DppSelector
static
constexpr
auto
GetDpp
();
static
constexpr
auto
GetDpp
();
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
8
,
32
>
()
{
{
return
DppInstr
::
dpp8_f16_8x32x2
;
return
DppInstr
::
dpp8_f16_8x32x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
{
{
return
DppInstr
::
dpp8_f16_8x16x2
;
return
DppInstr
::
dpp8_f16_8x16x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
{
{
return
DppInstr
::
dpp8_f16_16x16x2
;
return
DppInstr
::
dpp8_f16_16x16x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
32
,
8
>
()
constexpr
auto
GetDpp
<
half_t
,
32
,
8
>
()
{
{
return
DppInstr
::
dpp8_f16_32x8x2
;
return
DppInstr
::
dpp8_f16_32x8x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
{
{
return
DppInstr
::
dpp8_f16_1x32x2
;
return
DppInstr
::
dpp8_f16_1x32x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
{
{
return
DppInstr
::
dpp8_f16_2x32x2
;
return
DppInstr
::
dpp8_f16_2x32x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
{
{
return
DppInstr
::
dpp8_f16_2x16x2
;
return
DppInstr
::
dpp8_f16_2x16x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
{
{
return
DppInstr
::
dpp8_f16_4x16x2
;
return
DppInstr
::
dpp8_f16_4x16x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
{
{
return
DppInstr
::
dpp8_f16_4x32x2
;
return
DppInstr
::
dpp8_f16_4x32x2
;
}
}
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
e3d444c8
...
@@ -415,7 +415,7 @@ struct WmmaSelector
...
@@ -415,7 +415,7 @@ struct WmmaSelector
static
constexpr
auto
GetWmma
();
static
constexpr
auto
GetWmma
();
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
{
{
#ifdef __gfx12__
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
;
return
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
;
...
@@ -425,7 +425,7 @@ struct WmmaSelector
...
@@ -425,7 +425,7 @@ struct WmmaSelector
}
}
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
{
{
#ifdef __gfx12__
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_bf16_gfx12
;
return
WmmaInstr
::
wmma_f32_16x16x16_bf16_gfx12
;
...
@@ -435,19 +435,19 @@ struct WmmaSelector
...
@@ -435,19 +435,19 @@ struct WmmaSelector
}
}
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
half_t
,
16
,
16
>
()
constexpr
auto
GetWmma
<
half_t
,
half_t
,
half_t
,
16
,
16
>
()
{
{
return
WmmaInstr
::
wmma_f16_16x16x16_f16
;
return
WmmaInstr
::
wmma_f16_16x16x16_f16
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
bhalf_t
,
16
,
16
>
()
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
bhalf_t
,
16
,
16
>
()
{
{
return
WmmaInstr
::
wmma_bf16_16x16x16_bf16
;
return
WmmaInstr
::
wmma_bf16_16x16x16_bf16
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
{
{
#ifdef __gfx12__
#ifdef __gfx12__
return
WmmaInstr
::
wmma_i32_16x16x16_iu8_gfx12
;
return
WmmaInstr
::
wmma_i32_16x16x16_iu8_gfx12
;
...
@@ -458,7 +458,7 @@ struct WmmaSelector
...
@@ -458,7 +458,7 @@ struct WmmaSelector
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
{
{
return
WmmaInstr
::
wmma_i32_16x16x16_iu4
;
return
WmmaInstr
::
wmma_i32_16x16x16_iu4
;
}
}
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment