gaoqiong / composable_kernel_ROCM / Commits

Commit e878371c, authored Feb 07, 2024 by Adam Osewski

    Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

Parents: 9cb25b86, 753cef78
Changes: 97 files
Showing 20 changed files with 167 additions and 47 deletions (+167 / -47); the remaining changed files are on later pages of the diff.
CHANGELOG.md                                                          +1  -1
Dockerfile                                                            +2  -2
Jenkinsfile                                                           +17 -8
docs/sphinx/requirements.in                                           +1  -1
docs/sphinx/requirements.txt                                          +1  -1
docs/wrapper.rst                                                      +1  -0
example/35_splitK_gemm/CMakeLists.txt                                 +3  -0
example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp       +82 -0
include/ck/ck.hpp                                                     +27 -16
include/ck/host_utility/device_prop.hpp                               +19 -0
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp   +1 -2
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp    +1 -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp                     +1 -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp                 +1 -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp                       +1 -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp                     +4 -8
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp  +1 -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp               +1 -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +1 -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp         +1 -1
CHANGELOG.md

@@ -11,7 +11,7 @@
 None
 ### Additions
-* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126)
+* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126, #1139)
 ### Changes
 None
Dockerfile

@@ -122,7 +122,7 @@ ENV compiler_commit=$compiler_commit
 RUN sh -c "echo compiler version = '$compiler_version'"
 RUN sh -c "echo compiler commit = '$compiler_commit'"
-RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \
+RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \
        git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \
        cd llvm-project && mkdir build && cd build && \
        cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \

@@ -130,7 +130,7 @@ RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "am
    else echo "using the release compiler"; \
    fi
-RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \
+RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \
        git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \
        cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
        cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
Jenkinsfile

@@ -84,7 +84,7 @@ def build_compiler(){
         compiler = '/opt/rocm/bin/hipcc'
     }
     else{
-        if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
+        if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
            compiler = "/llvm-project/build/bin/clang++"
        }
        else{

@@ -135,6 +135,7 @@ def buildDocker(install_prefix){
        echo "Building image: ${image_name}"
        retimage = docker.build("${image_name}", dockerArgs + ' .')
        retimage.push()
+       sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi'
    }
    else{
        echo "Checking for image: ${image_name}"

@@ -293,7 +294,7 @@ def buildHipClangJob(Map conf=[:]){
        dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
    }
    def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
-   if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
+   if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
        dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
    }

@@ -348,7 +349,7 @@ def runCKProfiler(Map conf=[:]){
        dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
    }
    def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
-   if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
+   if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
        dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
    }

@@ -479,7 +480,7 @@ def Build_CK(Map conf=[:]){
        dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
    }
    def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
-   if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
+   if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
        dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
    }

@@ -657,7 +658,7 @@ def process_results(Map conf=[:]){
 //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
 CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.0;COMPILER_VERSION=
                                               0 21 * * * % ROCMVERSION=6.0;COMPILER_VERSION=;COMPILER_COMMIT=
-                                              0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=;USE_SCCACHE=false
+                                              0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false
                                               0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : ""
 pipeline {

@@ -680,7 +681,7 @@ pipeline {
        string(
            name: 'COMPILER_VERSION',
            defaultValue: '',
-           description: 'Specify which version of compiler to use: release, amd-stg-open, amd-mainline-open, or leave blank (default).')
+           description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline-open, or leave blank (default).')
        string(
            name: 'COMPILER_COMMIT',
            defaultValue: '',

@@ -713,6 +714,10 @@ pipeline {
            name: "RUN_CPPCHECK",
            defaultValue: false,
            description: "Run the cppcheck static analysis (default: OFF)")
+       booleanParam(
+           name: "RUN_PERFORMANCE_TESTS",
+           defaultValue: false,
+           description: "Run the performance tests (default: OFF)")
    }
    environment{
        dbuser = "${dbuser}"

@@ -890,7 +895,7 @@ pipeline {
        {
            when {
                beforeAgent true
-               expression { !params.RUN_FULL_QA.toBoolean() }
+               expression { !params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() }
            }
            options { retry(2) }
            agent{ label rocmnode("gfx908 || gfx90a")}

@@ -906,7 +911,7 @@ pipeline {
        {
            when {
                beforeAgent true
-               expression { params.RUN_FULL_QA.toBoolean() }
+               expression { params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() }
            }
            options { retry(2) }
            agent{ label rocmnode("gfx90a")}

@@ -925,6 +930,10 @@ pipeline {
        parallel
        {
            stage("Process results"){
+               when {
+                   beforeAgent true
+                   expression { params.RUN_PERFORMANCE_TESTS.toBoolean() }
+               }
                agent { label 'mici' }
                steps{
                    process_results()
docs/sphinx/requirements.in

-rocm-docs-core==0.33.0
+rocm-docs-core==0.33.2
 sphinxcontrib-bibtex==2.6.2
docs/sphinx/requirements.txt

@@ -113,7 +113,7 @@ requests==2.31.0
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==0.33.0
+rocm-docs-core==0.33.2
     # via -r requirements.in
 six==1.16.0
     # via
docs/wrapper.rst

@@ -89,3 +89,4 @@ Operations
 -------------------------------------
 .. doxygenfile:: copy.hpp
+.. doxygenfile:: gemm.hpp
example/35_splitK_gemm/CMakeLists.txt

@@ -10,6 +10,9 @@ foreach(gpu IN LISTS GPU_TARGETS)
 add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
 add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
+add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp)
+add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16)
 add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
 add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#define DIRECT_LOAD 1
#if DIRECT_LOAD
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#endif
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/literals.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType   = F16;
using BDataType   = F16;
using AccDataType = F32;
using CDataType   = F16;

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

#if DIRECT_LOAD
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle_LdsDirectLoad
    // clang-format off
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    <ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>;
    // clang-format on
#else
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
    // clang-format off
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    <ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
    // clang-format on
#endif

#include "run_splitK_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); }
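The actual driver for this example lives in run_splitK_gemm_example.inc, which is not part of this diff. As a rough, hypothetical sketch only, this is how a CK split-K device operation of this kind is typically driven; the function name run_gemm_sketch and the exact MakeArgument parameter order are assumptions and may differ from the real include:

// Hedged sketch, not from the commit; reuses the type aliases defined in the file above.
#include <iostream>

bool run_gemm_sketch(const ADataType* p_a, const BDataType* p_b, CDataType* p_c,
                     ck::index_t M, ck::index_t N, ck::index_t K,
                     ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                     ck::index_t k_batch) // split-K factor
{
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K,
                                      StrideA, StrideB, StrideC,
                                      AElementOp{}, BElementOp{}, CElementOp{}, k_batch);

    // Not every problem size maps onto a given tiling; check before launching.
    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem size\n";
        return false;
    }

    // Run and time the kernel on the default HIP stream.
    const float ave_time_ms = invoker.Run(argument, StreamConfig{nullptr, true});
    std::cout << "Kernel time: " << ave_time_ms << " ms\n";
    return true;
}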
include/ck/ck.hpp

@@ -44,16 +44,30 @@
 #define CK_USE_WAVES_PER_EU 0
 #endif

+// define general macros for various architectures
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#define __gfx94__
+#endif
+#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
+#define __gfx101__
+#endif
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
+    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
+#define __gfx103__
+#endif
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
+#define __gfx11__
+#endif
+
 // buffer resource
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_BUFFER_RESOURCE_3RD_DWORD -1
 #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__) // for GPU code
+    defined(__gfx90a__) || defined(__gfx94__) // for GPU code
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
-#elif defined(__gfx1030__) // for GPU code
+#elif defined(__gfx103__)
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
-#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
+#elif defined(__gfx11__)
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
 #endif

@@ -61,12 +75,12 @@
 #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
 #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
 #define CK_USE_AMD_V_MAC_F32
-#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // for GPU code
+#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \
+    defined(__gfx94__) // for GPU code
 #define CK_USE_AMD_V_FMAC_F32
 #define CK_USE_AMD_V_DOT2_F32_F16
 #define CK_USE_AMD_V_DOT4_I32_I8
-#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#elif defined(__gfx11__)
 #define CK_USE_AMD_V_FMAC_F32
 #define CK_USE_AMD_V_DOT2_F32_F16
 #define CK_USE_AMD_V_DOT4_I32_I8_GFX11

@@ -75,23 +89,22 @@
 // MFMA instruction
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_USE_AMD_MFMA
-#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__) // for GPU code
+#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
 #define CK_USE_AMD_MFMA
 #endif

-#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(defined(__gfx90a__) || defined(__gfx94__))
 #define CK_USE_AMD_MFMA_BF16_1K_OP
 #endif

-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
 #define CK_USE_AMD_MFMA_GFX940
 #endif

 // WMMA instruction
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_USE_AMD_WMMA
-#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
+#elif defined(__gfx11__) // for GPU code
 #define CK_USE_AMD_WMMA
 #endif

@@ -107,15 +120,13 @@
 // buffer atomic add: floating point
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
-#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__) // for GPU code
+#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
 #else // for GPU code
 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
 #endif

-#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__)) // for GPU code
+#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code
 #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
 #else
 #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
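The practical effect of the new family macros is that device-side guards no longer have to enumerate every gfx940/941/942 or gfx1100-1103 target individually. A minimal, hypothetical illustration (this kernel is not part of the commit; it only shows how the macros from ck.hpp are meant to be consumed):

#include "ck/ck.hpp"

// Hypothetical kernel, for illustration only: one family check replaces
// three or four per-target checks after this commit.
__global__ void example_arch_dispatch(float* p)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
    // MFMA-capable targets (CK_USE_AMD_MFMA is defined for these).
    p[0] = 1.0f;
#elif defined(__gfx11__)
    // RDNA3 targets (CK_USE_AMD_WMMA is defined here).
    p[0] = 2.0f;
#else
    p[0] = 0.0f;
#endif
}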
include/ck/host_utility/device_prop.hpp

@@ -65,4 +65,23 @@ inline bool is_lds_direct_load_supported()
            ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
 }

+inline bool is_navi1_supported()
+{
+    return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
+           ck::get_device_name() == "gfx1012";
+}
+
+inline bool is_navi2_supported()
+{
+    return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" ||
+           ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" ||
+           ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
+}
+
+inline bool is_navi3_supported()
+{
+    return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
+           ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
+}
+
 } // namespace ck
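On the host side these predicates replace long chains of ck::get_device_name() comparisons when gating kernel support, as the device-operation changes below show. A small hypothetical usage sketch (not part of the commit):

#include <iostream>
#include "ck/host_utility/device_prop.hpp"

// Hypothetical host-side check; it mirrors how the IsSupportedArgument()
// changes below consume the new predicates.
int main()
{
    const auto name = ck::get_device_name();

    if(ck::is_navi3_supported())
        std::cout << name << ": gfx11xx (WMMA) code paths available\n";
    else if(ck::is_navi2_supported() || ck::is_navi1_supported())
        std::cout << name << ": gfx10xx (DL) code paths only\n";
    else
        std::cout << name << ": check XDL/MFMA support separately\n";

    return 0;
}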
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp

@@ -770,8 +770,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp

@@ -57,7 +57,7 @@ __global__ void
     const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp

@@ -75,7 +75,7 @@ __global__ void
     const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp

@@ -61,7 +61,7 @@ __global__ void
     const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp

@@ -84,7 +84,7 @@ __global__ void
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp

@@ -70,9 +70,8 @@ __global__ void
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
     const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
-    defined(__gfx1101__) || defined(__gfx1102__))
+    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -648,11 +647,8 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
-        // TODO: Enable for gfx90a after complier fix
-        if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx90a" ||
-           ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx1030" ||
-           ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
-           ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
+        if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
+           ck::is_navi2_supported() || ck::is_navi3_supported())
         {
             bool pass = true;
             pass = pass && arg.K_ % K1 == 0;
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp

@@ -69,7 +69,7 @@ __global__ void
     const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp

@@ -60,7 +60,7 @@ __global__ void
     const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp

@@ -68,7 +68,7 @@ __global__ void
     const C0MatrixMask c0_matrix_mask)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp

@@ -63,7 +63,7 @@ __global__ void
     const C0MatrixMask c0_matrix_mask)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);