Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
941d1f7c
Unverified
Commit
941d1f7c
authored
Jun 27, 2024
by
Illia Silin
Committed by
GitHub
Jun 27, 2024
Browse files
Merging the gfx12 code into public repo. (#1362)
parent
a32b1bc6
Changes
49
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
589 additions
and
73 deletions
+589
-73
CMakeLists.txt
CMakeLists.txt
+4
-2
Jenkinsfile
Jenkinsfile
+1
-3
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-1
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+27
-27
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+1
-1
example/04_gemm_add_add_fastgelu/CMakeLists.txt
example/04_gemm_add_add_fastgelu/CMakeLists.txt
+1
-1
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
..._bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
+2
-2
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
..._scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
+3
-3
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
...m_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
+3
-3
example/CMakeLists.txt
example/CMakeLists.txt
+2
-2
include/ck/ck.hpp
include/ck/ck.hpp
+5
-9
include/ck/host_utility/device_prop.hpp
include/ck/host_utility/device_prop.hpp
+5
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+499
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+7
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
...l/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
+16
-7
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
...ion/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
+4
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...evice_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+5
-4
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+1
-3
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
...gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
...or_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
+1
-1
No files found.
CMakeLists.txt
View file @
941d1f7c
...
@@ -117,7 +117,7 @@ else()
...
@@ -117,7 +117,7 @@ else()
add_definitions
(
-DPROFILER_ONLY
)
add_definitions
(
-DPROFILER_ONLY
)
set
(
GPU_TARGETS
""
CACHE STRING
""
FORCE
)
set
(
GPU_TARGETS
""
CACHE STRING
""
FORCE
)
if
(
GPU_TARGETS
)
if
(
GPU_TARGETS
)
message
(
FATAL_ERROR
"For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx1
1
"
)
message
(
FATAL_ERROR
"For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10,
gfx11
or gfx1
2
"
)
endif
()
endif
()
if
(
GPU_ARCH MATCHES
"gfx90"
)
if
(
GPU_ARCH MATCHES
"gfx90"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx908;gfx90a"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx908;gfx90a"
)
...
@@ -127,8 +127,10 @@ else()
...
@@ -127,8 +127,10 @@ else()
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx1030"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx1030"
)
elseif
(
GPU_ARCH MATCHES
"gfx11"
)
elseif
(
GPU_ARCH MATCHES
"gfx11"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx1100;gfx1101;gfx1102"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx1100;gfx1101;gfx1102"
)
elseif
(
GPU_ARCH MATCHES
"gfx12"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx1200;gfx1201"
)
else
()
else
()
message
(
FATAL_ERROR
"For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx1
1
"
)
message
(
FATAL_ERROR
"For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10,
gfx11
or gfx1
2
"
)
endif
()
endif
()
set
(
GPU_TARGETS
"
${
DEFAULT_GPU_TARGETS
}
"
CACHE STRING
" "
FORCE
)
set
(
GPU_TARGETS
"
${
DEFAULT_GPU_TARGETS
}
"
CACHE STRING
" "
FORCE
)
endif
()
endif
()
...
...
Jenkinsfile
View file @
941d1f7c
...
@@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){
...
@@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){
def
variant
=
env
.
STAGE_NAME
def
variant
=
env
.
STAGE_NAME
def
retimage
def
retimage
gitStatusWrapper
(
credentialsId:
"${env.status_wrapper_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCm'
,
repo:
'composable_kernel'
)
{
gitStatusWrapper
(
credentialsId:
"${env.status_wrapper_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCm'
,
repo:
'composable_kernel'
)
{
try
{
try
{
(
retimage
,
image
)
=
getDockerImage
(
conf
)
(
retimage
,
image
)
=
getDockerImage
(
conf
)
...
@@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
...
@@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
pipeline
{
pipeline
{
agent
none
agent
none
triggers
{
parameterizedCron
(
CRON_SETTINGS
)
}
options
{
options
{
parallelsAlwaysFailFast
()
parallelsAlwaysFailFast
()
}
}
...
...
cmake/EnableCompilerWarnings.cmake
View file @
941d1f7c
...
@@ -66,7 +66,7 @@ else()
...
@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunreachable-code
-Wunused
-Wunused
-Wno-reserved-identifier
-Wno-reserved-identifier
-Werror
-Werror
-Wno-option-ignored
-Wno-option-ignored
-Wsign-compare
-Wsign-compare
-Wno-extra-semi-stmt
-Wno-extra-semi-stmt
...
...
example/01_gemm/gemm_wmma_fp16.cpp
View file @
941d1f7c
...
@@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
...
@@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmWmma_CShuffle
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmWmma_CShuffle
<
ALayout
,
<
ALayout
,
BLayout
,
BLayout
,
CLayout
,
CLayout
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
AccDataType
,
CShuffleDataType
,
CShuffleDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CElementOp
,
CElementOp
,
GemmDefault
,
GemmDefault
,
1
,
// Prefetch stage
1
,
// Prefetch stage
128
,
// BlockSize
128
,
// BlockSize
64
,
// MPerBlock
64
,
// MPerBlock
128
,
// NPerBlock
128
,
// NPerBlock
64
,
// KPerBlock
64
,
// KPerBlock
8
,
// K1
2
,
// K1
16
,
// MPerWmma
16
,
// MPerWmma
16
,
// NPerWmma
16
,
// NPerWmma
2
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
2
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
4
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
4
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
S
<
4
,
32
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
2
,
8
,
2
,
true
,
true
,
S
<
4
,
32
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
2
,
8
,
2
,
true
,
true
,
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
S
<
1
,
32
,
1
,
4
>
,
S
<
1
,
32
,
1
,
4
>
,
8
>
;
8
>
;
// clang-format on
// clang-format on
...
...
example/01_gemm/run_gemm_example.inc
View file @
941d1f7c
...
@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
break
;
break
;
case
4
:
case
4
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
1
.
f
,
1
.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5
.
f
,
5
.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
1.
f
,
1.
f
}(
b_k_n
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
1.
f
,
1.
f
}(
b_k_n
);
break
;
break
;
case
5
:
case
5
:
...
...
example/04_gemm_add_add_fastgelu/CMakeLists.txt
View file @
941d1f7c
...
@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32
)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
\ No newline at end of file
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
View file @
941d1f7c
...
@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
...
@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
2
,
2
,
4
,
4
,
4
,
4
,
tru
e
,
fals
e
,
S
<
4
,
32
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
4
,
4
,
4
,
4
,
tru
e
,
fals
e
,
1
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
S
<
1
,
64
,
1
,
2
>
,
...
...
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
View file @
941d1f7c
...
@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
...
@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
//
#define CK_MHA_USE_WAVE_8
using
DeviceMHAFactory
=
using
DeviceMHAFactory
=
std
::
tuple
<
std
::
tuple
<
#ifdef CK_MHA_USE_WAVE_1
#ifdef CK_MHA_USE_WAVE_1
...
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
...
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
// CShuffleBlockTransfer MN
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
MaskingSpec
>
,
MaskingSpec
>
#endif
#endif
#ifdef CK_MHA_USE_WAVE_8
#ifdef CK_MHA_USE_WAVE_8
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
...
...
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
View file @
941d1f7c
...
@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
...
@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
//
#define CK_MHA_USE_WAVE_8
using
DeviceMHAFactory
=
using
DeviceMHAFactory
=
std
::
tuple
<
std
::
tuple
<
#ifdef CK_MHA_USE_WAVE_1
#ifdef CK_MHA_USE_WAVE_1
...
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
...
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
// CShuffleBlockTransfer MN
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
MaskingSpec
>
,
MaskingSpec
>
#endif
#endif
#ifdef CK_MHA_USE_WAVE_8
#ifdef CK_MHA_USE_WAVE_8
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
...
...
example/CMakeLists.txt
View file @
941d1f7c
...
@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
...
@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endforeach
()
endforeach
()
#Do not build any WMMA examples if gfx11 targets are not on the list
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach
(
source IN LISTS FILE_NAME
)
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT
EX
_TARGETS MATCHES
"gfx1
1
"
AND source MATCHES
"_wmma"
)
if
(
NOT GPU_TARGETS MATCHES
"gfx11"
AND
NOT
GPU
_TARGETS MATCHES
"gfx1
2
"
AND source MATCHES
"_wmma"
)
message
(
"removing wmma example
${
source
}
"
)
message
(
"removing wmma example
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endif
()
...
@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
...
@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endforeach
()
endforeach
()
#Do not build any WMMA examples if gfx11 targets are not on the list
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach
(
source IN LISTS FILE_NAME
)
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT
EX
_TARGETS MATCHES
"gfx1
1
"
AND source MATCHES
"_wmma"
)
if
(
NOT GPU_TARGETS MATCHES
"gfx11"
AND
NOT
GPU
_TARGETS MATCHES
"gfx1
2
"
AND source MATCHES
"_wmma"
)
message
(
"removing wmma example
${
source
}
"
)
message
(
"removing wmma example
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endif
()
...
...
include/ck/ck.hpp
View file @
941d1f7c
...
@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#define __gfx11__
#endif
#endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
// buffer resource
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifndef __HIP_DEVICE_COMPILE__ // for host code
...
@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__)
#elif defined(__gfx103__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11__)
#elif defined(__gfx11__)
|| defined(__gfx12__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
#endif
...
@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
#define CK_USE_AMD_V_DOT4_I32_I8
#elif defined(__gfx11__)
#elif defined(__gfx11__)
|| defined(__gfx12__)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
...
@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_USE_AMD_MFMA_GFX940
#define CK_USE_AMD_MFMA_GFX940
#endif
#endif
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#elif defined(__gfx11__) // for GPU code
#define CK_USE_AMD_WMMA
#endif
// buffer load
// buffer load
#define CK_USE_AMD_BUFFER_LOAD 1
#define CK_USE_AMD_BUFFER_LOAD 1
...
...
include/ck/host_utility/device_prop.hpp
View file @
941d1f7c
...
@@ -84,4 +84,9 @@ inline bool is_gfx11_supported()
...
@@ -84,4 +84,9 @@ inline bool is_gfx11_supported()
ck
::
get_device_name
()
==
"gfx1102"
||
ck
::
get_device_name
()
==
"gfx1103"
;
ck
::
get_device_name
()
==
"gfx1102"
||
ck
::
get_device_name
()
==
"gfx1103"
;
}
}
inline
bool
is_gfx12_supported
()
{
return
ck
::
get_device_name
()
==
"gfx1200"
||
ck
::
get_device_name
()
==
"gfx1201"
;
}
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
941d1f7c
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
941d1f7c
...
@@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// sync point.
// sync point.
if
constexpr
(
k
.
value
!=
0
||
KPerInnerLoop
==
KPerThread
)
if
constexpr
(
k
.
value
!=
0
||
KPerInnerLoop
==
KPerThread
)
{
{
#ifdef __gfx12__
asm
volatile
(
"\
s_barrier_signal -1
\n
\
s_barrier_wait -1 \
"
::
);
#else
asm
volatile
(
"s_barrier"
::
);
asm
volatile
(
"s_barrier"
::
);
#endif
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
View file @
941d1f7c
...
@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
...
@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
WmmaK
=
K1
==
16
?
32
:
16
;
static
constexpr
auto
WmmaK
=
K1
==
16
?
32
:
16
;
static
constexpr
auto
AEnableLds_auto
=
NWaves
==
1
?
false
:
true
;
static
constexpr
auto
MaxVectorLoadA
=
K1
*
sizeof
(
ADataType
)
==
16
?
true
:
false
;
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
MaxVectorLoadB
=
K1
*
sizeof
(
BDataType
)
==
16
?
true
:
false
;
static
constexpr
auto
AEnableLds_auto
=
(
NWaves
==
1
&&
(
MaxVectorLoadA
||
MRepeat
==
1
))
?
false
:
true
;
static
constexpr
auto
BEnableLds_auto
=
(
MWaves
==
1
&&
(
MaxVectorLoadB
||
NRepeat
==
1
))
?
false
:
true
;
// If true, LDS is used unconditionally
// If true, LDS is used unconditionally
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
AEnableLds_manu
=
false
;
...
@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
...
@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
ck
::
is_gfx11_supported
())
if
(
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
()
)
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
{
...
@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
...
@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
}
}
else
else
{
{
if
(
!
(
arg
.
a_kz_stride_
==
1
&&
if
(
!
(
arg
.
a_kz_stride_
==
1
))
arg
.
a_grid_desc_
.
GetLength
(
I2
)
%
ABlockTransferSrcScalarPerVector
==
0
))
{
{
printf
(
"DeviceOp: Vector Access A-k check failure
\n
"
);
index_t
LastK
=
return
false
;
AEnableLds
?
arg
.
a_grid_desc_
.
GetLength
(
I2
)
:
arg
.
a_grid_desc_
.
GetLength
(
I6
);
if
(
LastK
%
ABlockTransferSrcScalarPerVector
==
0
)
{
printf
(
"DeviceOp: Vector Access A-k check failure
\n
"
);
return
false
;
}
}
}
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
View file @
941d1f7c
...
@@ -70,8 +70,9 @@ __global__ void
...
@@ -70,8 +70,9 @@ __global__ void
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
...
@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_xdl_supported
()
||
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_xdl_supported
()
||
ck
::
is_gfx103_supported
()
||
ck
::
is_gfx11_supported
())
ck
::
is_gfx103_supported
()
||
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
()
)
{
{
bool
pass
=
true
;
bool
pass
=
true
;
pass
=
pass
&&
arg
.
K_
%
K1
==
0
;
pass
=
pass
&&
arg
.
K_
%
K1
==
0
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
View file @
941d1f7c
...
@@ -56,7 +56,7 @@ __global__ void
...
@@ -56,7 +56,7 @@ __global__ void
bool
input_permute
,
bool
input_permute
,
bool
output_permute
)
bool
output_permute
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// clang-format off
// clang-format off
// ***************************************************
// ***************************************************
...
@@ -159,6 +159,7 @@ __global__ void
...
@@ -159,6 +159,7 @@ __global__ void
ignore
=
O
;
ignore
=
O
;
ignore
=
G0
;
ignore
=
G0
;
ignore
=
G1
;
ignore
=
G1
;
ignore
=
alpha
;
ignore
=
input_permute
;
ignore
=
input_permute
;
ignore
=
output_permute
;
ignore
=
output_permute
;
#endif // end of if (defined(__gfx11__))
#endif // end of if (defined(__gfx11__))
...
@@ -187,7 +188,7 @@ __global__ void
...
@@ -187,7 +188,7 @@ __global__ void
index_t
head_size
,
index_t
head_size
,
float
alpha
)
float
alpha
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// clang-format off
// clang-format off
// ***************************************************
// ***************************************************
...
@@ -321,7 +322,7 @@ __global__ void
...
@@ -321,7 +322,7 @@ __global__ void
index_t
head_size
,
index_t
head_size
,
float
alpha
)
float
alpha
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// clang-format off
// clang-format off
// ***************************************************
// ***************************************************
...
@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
{
{
if
(
ck
::
is_gfx11_supported
())
if
(
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
()
)
{
{
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
941d1f7c
...
@@ -592,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
...
@@ -592,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return
false
;
return
false
;
}
}
if
(
ck
::
get_device_name
()
!=
"gfx90a"
&&
ck
::
get_device_name
()
!=
"gfx940"
&&
if
(
!
ck
::
is_lds_direct_load_supported
()
&&
std
::
is_same
<
ADataType
,
double
>::
value
)
ck
::
get_device_name
()
!=
"gfx941"
&&
ck
::
get_device_name
()
!=
"gfx942"
&&
std
::
is_same
<
ADataType
,
double
>::
value
)
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
View file @
941d1f7c
...
@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
...
@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
{
{
// check device
// check device
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_gfx103_supported
()
||
if
(
!
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_gfx103_supported
()
||
ck
::
is_gfx11_supported
()))
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
()
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
View file @
941d1f7c
...
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
...
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
ck
::
is_gfx11_supported
())
if
(
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
()
)
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
ck
::
half_t
>
||
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
ck
::
half_t
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
is_same_v
<
AccDataType
,
int32_t
>
))
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment