Commit 5cb59d36 (gaoqiong/composable_kernel_ROCM)
Authored Apr 07, 2024 by Jing Zhang

    resolve conflicts

Merge parents: 7e3a5613, 7e147c64

This merge commit primarily extends gfx12 (RDNA4) support: gfx1201 is added next to gfx1200 in device checks and compile guards, defined(__gfx12__) is added to the DL kernel guards, and gfx1103 joins several example target lists. It also carries smaller build, CI, and documentation fixes. The commit changes 36 files in total; the 20 files shown on this page account for 66 additions and 53 deletions (+66 -53).
Changed files:

CMakeLists.txt (+2/-2)
Jenkinsfile (+2/-2)
cmake/EnableCompilerWarnings.cmake (+1/-1)
docs/conf.py (+2/-0)
docs/wrapper.rst (+7/-7)
example/01_gemm/CMakeLists.txt (+11/-7)
example/01_gemm/run_gemm_example.inc (+3/-3)
example/02_gemm_bilinear/CMakeLists.txt (+2/-2)
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt (+1/-1)
include/ck/ck.hpp (+1/-8)
include/ck/host_utility/device_prop.hpp (+4/-1)
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp (+7/-0)
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp (+2/-2)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp (+4/-3)
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp (+1/-1)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp (+4/-3)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp (+2/-2)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp (+4/-3)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp (+2/-2)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp (+4/-3)
CMakeLists.txt

@@ -113,7 +113,7 @@ message("checking which targets are supported")
 #Setting GPU_TARGETS on command line will override this list
 if(NOT PROFILER_ONLY)
     rocm_check_target_ids(DEFAULT_GPU_TARGETS
-        TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200")
+        TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
 else()
     add_definitions(-DPROFILER_ONLY)
     set(GPU_TARGETS "" CACHE STRING "" FORCE)

@@ -129,7 +129,7 @@ else()
 elseif(GPU_ARCH MATCHES "gfx11")
     rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
 elseif(GPU_ARCH MATCHES "gfx12")
-    rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200")
+    rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201")
 else()
     message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12")
 endif()
Jenkinsfile

@@ -496,7 +496,7 @@ def Build_CK(Map conf=[:]){
     def navi_node = 0
     def mi300_node = 0
-    gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel-internal') {
+    gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel-internal') {
         try {
             (retimage, image) = getDockerImage(conf)
             withDockerContainer(image: image, args: dockerOpts) {

@@ -602,7 +602,7 @@ def process_results(Map conf=[:]){
     def variant = env.STAGE_NAME
     def retimage
-    gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel-internal') {
+    gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel-internal') {
         try {
             (retimage, image) = getDockerImage(conf)
         }
cmake/EnableCompilerWarnings.cmake

@@ -66,7 +66,7 @@ else()
        -Wunreachable-code
        -Wunused
        -Wno-reserved-identifier
-       # -Werror
+       -Werror
        -Wno-option-ignored
        -Wsign-compare
        -Wno-extra-semi-stmt
docs/conf.py

@@ -45,3 +45,5 @@ for sphinx_var in ROCmDocs.SPHINX_VARS:
 extensions += ['sphinxcontrib.bibtex']
 bibtex_bibfiles = ['refs.bib']
+
+cpp_id_attributes = ["__global__", "__device__", "__host__"]
docs/wrapper.rst

@@ -64,31 +64,31 @@ Advanced examples:
 Layout
 -------------------------------------

-.. doxygenstruct:: ck::wrapper::Layout
+.. doxygenstruct:: Layout

 -------------------------------------
 Layout helpers
 -------------------------------------

-.. doxygenfile:: layout_utils.hpp
+.. doxygenfile:: include/ck/wrapper/utils/layout_utils.hpp

 -------------------------------------
 Tensor
 -------------------------------------

-.. doxygenstruct:: ck::wrapper::Tensor
+.. doxygenstruct:: Tensor

 -------------------------------------
 Tensor helpers
 -------------------------------------

-.. doxygenfile:: tensor_utils.hpp
-.. doxygenfile:: tensor_partition.hpp
+.. doxygenfile:: include/ck/wrapper/utils/tensor_utils.hpp
+.. doxygenfile:: include/ck/wrapper/utils/tensor_partition.hpp

 -------------------------------------
 Operations
 -------------------------------------

-.. doxygenfile:: copy.hpp
-.. doxygenfile:: gemm.hpp
+.. doxygenfile:: include/ck/wrapper/operations/copy.hpp
+.. doxygenfile:: include/ck/wrapper/operations/gemm.hpp
example/01_gemm/CMakeLists.txt

@@ -27,7 +27,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
 add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
-if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102" OR GPU_TARGETS MATCHES "gfx1200")
+if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
     add_custom_target(example_gemm_wmma)
     add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
     add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)

@@ -53,12 +54,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
 add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
-add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
-add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
-add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
-add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
 list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)

@@ -71,3 +66,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
         set(target 1)
     endif()
 endforeach()
+add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
+add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
+add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
example/01_gemm/run_gemm_example.inc

@@ -155,12 +155,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
         ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
         break;
     case 3:
-        ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
         break;
    case 4:
-        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
-        ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
         break;
     case 5:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
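A note on the fill utilities above, with a minimal sketch. Integer-valued fills keep every product and partial sum exactly representable in low-precision types, so the host-side reference check can compare results exactly rather than within a tolerance. The sketch below reuses names from the diff; the surrounding tensors and type aliases are assumed to be defined elsewhere in run_gemm_example.inc:

    // Sketch only: assumes ADataType/BDataType and host tensors a_m_k, b_k_n
    // as defined earlier in this file. Every sampled value is a whole number,
    // so low-precision (e.g. fp8) GEMM accumulation stays exact and the
    // verification can match bit-for-bit.
    ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
    ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);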
example/02_gemm_bilinear/CMakeLists.txt

-list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102 gfx1200)
+list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
 list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)

@@ -6,7 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
         add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
         add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
     endif()
-    if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
+    if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94")
         set(target 1)
     endif()
 endforeach()
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt

 list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
-list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1200)
+list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
include/ck/ck.hpp

@@ -58,7 +58,7 @@
 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
 #define __gfx11__
 #endif
-#if defined(__gfx1200__)
+#if defined(__gfx1200__) || defined(__gfx1201__)
 #define __gfx12__
 #endif

@@ -104,13 +104,6 @@
 #define CK_USE_AMD_MFMA_GFX940
 #endif
-// WMMA instruction
-#ifndef __HIP_DEVICE_COMPILE__ // for host code
-#define CK_USE_AMD_WMMA
-#elif defined(__gfx11__) // for GPU code
-#define CK_USE_AMD_WMMA
-#endif
 // buffer load
 #define CK_USE_AMD_BUFFER_LOAD 1
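For context, the __gfx11__/__gfx12__ umbrella macros let the rest of the library branch per ISA family rather than per chip. A minimal sketch of the intended usage (illustrative function name, not part of this commit):

    #include "ck/ck.hpp"

    // Illustrative only: selects a code path per GPU family during device
    // compilation. With this commit, gfx1201 now lands in the __gfx12__ branch.
    __device__ void family_specific_path()
    {
    #if defined(__gfx12__) // gfx1200 or gfx1201 (RDNA4)
        // RDNA4 path
    #elif defined(__gfx11__) // gfx1100, gfx1101, gfx1102, gfx1103 (RDNA3)
        // RDNA3 path
    #endif
    }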
include/ck/host_utility/device_prop.hpp

@@ -85,6 +85,9 @@ inline bool is_navi3_supported()
            ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
 }
-inline bool is_navi4_supported() { return ck::get_device_name() == "gfx1200"; }
+inline bool is_navi4_supported()
+{
+    return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
+}
 } // namespace ck
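These helpers compare the device name reported at runtime against known architecture strings; is_navi4_supported() now also recognizes gfx1201. A sketch of the host-side gating they enable, mirroring the IsSupportedArgument() changes later in this commit (the helper name here is hypothetical):

    #include "ck/host_utility/device_prop.hpp"

    // Mirrors the dispatch check the DL device ops in this commit adopt:
    // usable on gfx906, any XDL-capable chip, or any Navi2x/Navi3x/Navi4x chip.
    inline bool dl_kernel_device_supported()
    {
        return ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
               ck::is_navi2_supported() || ck::is_navi3_supported() ||
               ck::is_navi4_supported();
    }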
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp

@@ -488,7 +488,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
             // sync point.
             if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
             {
+#ifdef __gfx12__
+                asm volatile("\
+                s_barrier_signal -1 \n \
+                s_barrier_wait -1 \
+                " ::);
+#else
                 asm volatile("s_barrier" ::);
+#endif
                 __builtin_amdgcn_sched_barrier(0);
             }
             static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
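The inline-assembly change reflects an ISA difference: on gfx12 the single s_barrier instruction is replaced by a split signal/wait barrier pair. A self-contained sketch of the same selection, factored into a helper (hypothetical name; -1 selects the default workgroup barrier, as in the diff):

    // Hypothetical helper illustrating the guarded barrier above.
    __device__ inline void block_barrier()
    {
    #ifdef __gfx12__
        // gfx12 split barrier: signal arrival, then wait for the workgroup.
        asm volatile("s_barrier_signal -1\n\t"
                     "s_barrier_wait -1" ::);
    #else
        // Older targets: single monolithic barrier instruction.
        asm volatile("s_barrier" ::);
    #endif
    }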
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp

@@ -137,8 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
     static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
     // If true, LDS is used unconditionally
-    static constexpr auto AEnableLds_manu = false;
-    static constexpr auto BEnableLds_manu = false;
+    static constexpr auto AEnableLds_manu = true;
+    static constexpr auto BEnableLds_manu = true;
     static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
     static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp

@@ -70,8 +70,9 @@ __global__ void
         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
         const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
+    defined(__gfx12__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
-           ck::is_navi2_supported() || ck::is_navi3_supported())
+           ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             bool pass = true;
             pass = pass && arg.K_ % K1 == 0;
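The #if around the kernel body follows a pattern used throughout these device files: the host compilation pass (where __HIP_DEVICE_COMPILE__ is undefined) always sees the body so that launch code type-checks, while each device pass compiles the body only when its target is on the list. A stripped-down sketch under those assumptions (illustrative kernel, not a real CK entry point; the real guards list the full target set as in the diff):

    // Illustrative skeleton of the guard pattern.
    template <typename GridwiseGemm>
    __global__ void kernel_gemm_sketch()
    {
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
        // Real body: seen by the host pass and by supported device passes.
    #else
        // Unsupported device target: an empty kernel is emitted instead.
    #endif
    }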
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp

@@ -1394,7 +1394,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
     {
         // check device
         if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
-             ck::is_navi3_supported()))
+             ck::is_navi3_supported() || ck::is_navi4_supported()))
         {
             return false;
         }
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp

@@ -50,8 +50,9 @@ __global__ void
         const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
         const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
+    defined(__gfx12__))
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);

@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
-           ck::is_navi2_supported() || ck::is_navi3_supported())
+           ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             return GridwiseGemm::CheckValidity(
                 arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp

@@ -101,8 +101,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
         (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
     // If true, LDS is used unconditionally
-    static constexpr auto AEnableLds_manu = false;
-    static constexpr auto BEnableLds_manu = false;
+    static constexpr auto AEnableLds_manu = true;
+    static constexpr auto BEnableLds_manu = true;
     static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
     static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
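Why flipping the manual flags is sufficient here: the effective switches are an OR of the auto heuristic, the manual override, and the multi-buffer prefetch requirement, so setting *_manu to true forces LDS staging unconditionally. A condensed restatement of the logic from the diff (the surrounding template parameters are assumed):

    // With *_manu == true the OR below is always true, so both operands are
    // staged through LDS regardless of the MWaves/BLayout heuristic or NumPrefetch.
    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);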
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp

@@ -90,8 +90,9 @@ __global__ void
         const Block2CTileMap block_2_ctile_map,
         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
-    defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
+    defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
+    defined(__gfx12__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -666,7 +667,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
         // check device
         if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
-             ck::is_navi2_supported() || ck::is_navi3_supported()))
+             ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported()))
         {
             return false;
         }
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp

@@ -107,7 +107,7 @@ __global__ void
         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
-    defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
+    defined(__gfx11__) || defined(__gfx12__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -602,7 +602,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
         // check device
         if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
-             ck::is_navi3_supported()))
+             ck::is_navi3_supported() || ck::is_navi4_supported()))
         {
             return false;
         }
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp

@@ -39,8 +39,9 @@ __global__ void
         const BElementwiseOperation b_element_op,
         const CDEElementwiseOperation cde_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+    defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \
+    defined(__gfx12__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t block_id = get_block_1d_id();

@@ -668,7 +669,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
         }
         if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
-           ck::is_navi2_supported() || ck::is_navi3_supported())
+           ck::is_navi2_supported() || ck::is_navi3_supported() || ck::is_navi4_supported())
         {
             for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
             {