Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9fa379ea
Unverified
Commit
9fa379ea
authored
Mar 12, 2024
by
Illia Silin
Committed by
GitHub
Mar 12, 2024
Browse files
Merge pull request #47 from ROCm/navi4x_wmma
Navi4x wmma GEMM
parents
9de63596
9a9cb884
Changes
22
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
775 additions
and
74 deletions
+775
-74
CMakeLists.txt
CMakeLists.txt
+4
-2
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-1
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+2
-2
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+41
-44
example/02_gemm_bilinear/CMakeLists.txt
example/02_gemm_bilinear/CMakeLists.txt
+1
-1
example/29_batched_gemm_bias_e_permute/CMakeLists.txt
example/29_batched_gemm_bias_e_permute/CMakeLists.txt
+1
-1
example/64_fpAintB_gemm/CMakeLists.txt
example/64_fpAintB_gemm/CMakeLists.txt
+1
-1
include/ck/ck.hpp
include/ck/ck.hpp
+6
-3
include/ck/host_utility/device_prop.hpp
include/ck/host_utility/device_prop.hpp
+2
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+473
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
...l/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
...or_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
...de/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
.../gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+3
-3
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
.../tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
+1
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+3
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+1
-1
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+146
-1
include/ck/utility/amd_wmma.hpp
include/ck/utility/amd_wmma.hpp
+82
-0
No files found.
CMakeLists.txt
View file @
9fa379ea
...
@@ -118,7 +118,7 @@ else()
...
@@ -118,7 +118,7 @@ else()
add_definitions
(
-DPROFILER_ONLY
)
add_definitions
(
-DPROFILER_ONLY
)
set
(
GPU_TARGETS
""
CACHE STRING
""
FORCE
)
set
(
GPU_TARGETS
""
CACHE STRING
""
FORCE
)
if
(
GPU_TARGETS
)
if
(
GPU_TARGETS
)
message
(
FATAL_ERROR
"For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx1
1
"
)
message
(
FATAL_ERROR
"For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10,
gfx11
or gfx1
2
"
)
endif
()
endif
()
if
(
GPU_ARCH MATCHES
"gfx90"
)
if
(
GPU_ARCH MATCHES
"gfx90"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx908;gfx90a"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx908;gfx90a"
)
...
@@ -128,8 +128,10 @@ else()
...
@@ -128,8 +128,10 @@ else()
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx1030"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx1030"
)
elseif
(
GPU_ARCH MATCHES
"gfx11"
)
elseif
(
GPU_ARCH MATCHES
"gfx11"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx1100;gfx1101;gfx1102"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx1100;gfx1101;gfx1102"
)
elseif
(
GPU_ARCH MATCHES
"gfx12"
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS TARGETS
"gfx1200;gfx1201"
)
else
()
else
()
message
(
FATAL_ERROR
"For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx1
1
"
)
message
(
FATAL_ERROR
"For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10,
gfx11
or gfx1
2
"
)
endif
()
endif
()
set
(
GPU_TARGETS
"
${
DEFAULT_GPU_TARGETS
}
"
CACHE STRING
" "
FORCE
)
set
(
GPU_TARGETS
"
${
DEFAULT_GPU_TARGETS
}
"
CACHE STRING
" "
FORCE
)
endif
()
endif
()
...
...
cmake/EnableCompilerWarnings.cmake
View file @
9fa379ea
example/01_gemm/CMakeLists.txt
View file @
9fa379ea
...
@@ -27,7 +27,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
...
@@ -27,7 +27,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable
(
example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx12"
)
add_custom_target
(
example_gemm_wmma
)
add_custom_target
(
example_gemm_wmma
)
add_example_executable
(
example_gemm_wmma_fp16 gemm_wmma_fp16.cpp
)
add_example_executable
(
example_gemm_wmma_fp16 gemm_wmma_fp16.cpp
)
add_example_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
add_example_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
...
@@ -74,4 +75,3 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
...
@@ -74,4 +75,3 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
add_example_executable
(
example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8
)
example/01_gemm/gemm_wmma_fp16.cpp
View file @
9fa379ea
...
@@ -19,11 +19,9 @@ using AElementOp = PassThrough;
...
@@ -19,11 +19,9 @@ using AElementOp = PassThrough;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmWmma_CShuffle
<
ALayout
,
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmWmma_CShuffle
<
ALayout
,
BLayout
,
BLayout
,
CLayout
,
CLayout
,
ADataType
,
ADataType
,
...
@@ -35,35 +33,34 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
@@ -35,35 +33,34 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BElementOp
,
BElementOp
,
CElementOp
,
CElementOp
,
GemmDefault
,
GemmDefault
,
1
,
// Prefetch stage
1
,
128
,
// BlockSize
32
,
64
,
// MPerBlock
16
,
128
,
// NPerBlock
32
,
64
,
// KPerBlock
64
,
8
,
// K1
8
,
16
,
// MPerWmma
16
,
16
,
// NPerWmma
16
,
2
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
1
,
4
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
2
,
S
<
4
,
32
,
1
>
,
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
8
,
8
,
8
,
true
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
8
,
8
,
8
,
true
,
true
,
1
,
// C shuffle (M Repeat) Per store
1
,
1
,
// C shuffle (N Repeat) Per store
1
,
S
<
1
,
32
,
1
,
4
>
,
S
<
1
,
16
,
1
,
2
>
,
8
>
;
8
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
...
...
example/02_gemm_bilinear/CMakeLists.txt
View file @
9fa379ea
list
(
APPEND gpu_list1 gfx1100 gfx1101 gfx1102
)
list
(
APPEND gpu_list1 gfx1100 gfx1101 gfx1102
gfx1103 gfx1200 gfx1201
)
list
(
APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950
)
list
(
APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
...
...
example/29_batched_gemm_bias_e_permute/CMakeLists.txt
View file @
9fa379ea
add_example_executable
(
example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp
)
add_example_executable
(
example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx12"
)
add_example_executable
(
example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp
)
add_example_executable
(
example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp
)
endif
()
endif
()
example/64_fpAintB_gemm/CMakeLists.txt
View file @
9fa379ea
if
(
GPU_TARGETS MATCHES
"gfx11"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx12"
)
add_custom_target
(
example_fpAintB_gemm_wmma
)
add_custom_target
(
example_fpAintB_gemm_wmma
)
add_example_executable
(
example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp
)
add_example_executable
(
example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp
)
add_dependencies
(
example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma
)
add_dependencies
(
example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma
)
...
...
include/ck/ck.hpp
View file @
9fa379ea
...
@@ -58,6 +58,9 @@
...
@@ -58,6 +58,9 @@
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#define __gfx11__
#endif
#endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
// buffer resource
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifndef __HIP_DEVICE_COMPILE__ // for host code
...
@@ -67,7 +70,7 @@
...
@@ -67,7 +70,7 @@
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__)
#elif defined(__gfx103__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11__)
#elif defined(__gfx11__)
|| defined(__gfx12__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
#endif
...
@@ -80,7 +83,7 @@
...
@@ -80,7 +83,7 @@
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
#define CK_USE_AMD_V_DOT4_I32_I8
#elif defined(__gfx11__)
#elif defined(__gfx11__)
|| defined(__gfx12__)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
...
@@ -104,7 +107,7 @@
...
@@ -104,7 +107,7 @@
// WMMA instruction
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#define CK_USE_AMD_WMMA
#elif defined(__gfx11__) // for GPU code
#elif defined(__gfx11__)
|| defined(__gfx12__)
// for GPU code
#define CK_USE_AMD_WMMA
#define CK_USE_AMD_WMMA
#endif
#endif
...
...
include/ck/host_utility/device_prop.hpp
View file @
9fa379ea
...
@@ -85,4 +85,6 @@ inline bool is_navi3_supported()
...
@@ -85,4 +85,6 @@ inline bool is_navi3_supported()
ck
::
get_device_name
()
==
"gfx1102"
||
ck
::
get_device_name
()
==
"gfx1103"
;
ck
::
get_device_name
()
==
"gfx1102"
||
ck
::
get_device_name
()
==
"gfx1103"
;
}
}
inline
bool
is_navi4_supported
()
{
return
ck
::
get_device_name
()
==
"gfx1200"
;
}
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
9fa379ea
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
View file @
9fa379ea
...
@@ -137,8 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
...
@@ -137,8 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
// If true, LDS is used unconditionally
// If true, LDS is used unconditionally
static
constexpr
auto
AEnableLds_manu
=
fals
e
;
static
constexpr
auto
AEnableLds_manu
=
tru
e
;
static
constexpr
auto
BEnableLds_manu
=
fals
e
;
static
constexpr
auto
BEnableLds_manu
=
tru
e
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
View file @
9fa379ea
...
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
...
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
ck
::
is_navi3_supported
())
if
(
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
ck
::
half_t
>
||
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
ck
::
half_t
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
is_same_v
<
AccDataType
,
int32_t
>
))
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
View file @
9fa379ea
...
@@ -537,7 +537,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
...
@@ -537,7 +537,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
}
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_navi2_supported
()
||
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_navi2_supported
()
||
ck
::
is_navi3_supported
())
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
{
return
GridwiseGemm
::
CheckValidity
(
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
);
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
View file @
9fa379ea
...
@@ -101,8 +101,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -101,8 +101,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
(
MWaves
==
1
&&
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
?
false
:
true
;
(
MWaves
==
1
&&
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
?
false
:
true
;
// If true, LDS is used unconditionally
// If true, LDS is used unconditionally
static
constexpr
auto
AEnableLds_manu
=
fals
e
;
static
constexpr
auto
AEnableLds_manu
=
tru
e
;
static
constexpr
auto
BEnableLds_manu
=
fals
e
;
static
constexpr
auto
BEnableLds_manu
=
tru
e
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
...
@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
ck
::
is_navi3_supported
())
if
(
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
9fa379ea
...
@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
(
MWaves
==
1
&&
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
?
false
:
true
;
(
MWaves
==
1
&&
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
?
false
:
true
;
// If true, LDS is used unconditionally
// If true, LDS is used unconditionally
static
constexpr
auto
AEnableLds_manu
=
fals
e
;
static
constexpr
auto
AEnableLds_manu
=
tru
e
;
static
constexpr
auto
BEnableLds_manu
=
fals
e
;
static
constexpr
auto
BEnableLds_manu
=
tru
e
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
...
@@ -443,7 +443,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -443,7 +443,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
ck
::
is_navi3_supported
())
if
(
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
ck
::
half_t
>
||
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
ck
::
half_t
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
is_same_v
<
AccDataType
,
int32_t
>
))
...
...
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
View file @
9fa379ea
...
@@ -50,8 +50,7 @@ __global__ void
...
@@ -50,8 +50,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
defined(__gfx1102__))
__shared__
char
p_shared
[
GridwiseGemm
::
SharedMemTrait
::
lds_size
];
__shared__
char
p_shared
[
GridwiseGemm
::
SharedMemTrait
::
lds_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
9fa379ea
...
@@ -54,7 +54,7 @@ __global__ void
...
@@ -54,7 +54,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
,
const
Block2CTileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// offset base pointer for each work-group
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -147,7 +147,7 @@ __global__ void
...
@@ -147,7 +147,7 @@ __global__ void
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2CTileMap
block_2_etile_map
)
const
Block2CTileMap
block_2_etile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// printf("entry kernel launch");
// printf("entry kernel launch");
__shared__
char
p_shared
[
GridwiseOp
::
SharedMemTrait
::
lds_size
];
__shared__
char
p_shared
[
GridwiseOp
::
SharedMemTrait
::
lds_size
];
...
@@ -237,7 +237,7 @@ __global__ void
...
@@ -237,7 +237,7 @@ __global__ void
const
CDEElementwiseOperation
cde_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
__shared__
char
p_shared
[
GridwiseOp
::
SharedMemTrait
::
lds_size
];
__shared__
char
p_shared
[
GridwiseOp
::
SharedMemTrait
::
lds_size
];
GridwiseOp
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
GridwiseOp
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
9fa379ea
...
@@ -45,7 +45,7 @@ __global__ void
...
@@ -45,7 +45,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
__shared__
char
p_shared
[
GridwiseGemm
::
SharedMemTrait
::
lds_size
];
__shared__
char
p_shared
[
GridwiseGemm
::
SharedMemTrait
::
lds_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
9fa379ea
...
@@ -11,12 +11,17 @@ namespace ck {
...
@@ -11,12 +11,17 @@ namespace ck {
enum
struct
WmmaInstr
enum
struct
WmmaInstr
{
{
// gfx11
wmma_f32_16x16x16_f16
=
0
,
wmma_f32_16x16x16_f16
=
0
,
wmma_f32_16x16x16_bf16
,
wmma_f32_16x16x16_bf16
,
wmma_f16_16x16x16_f16
,
wmma_f16_16x16x16_f16
,
wmma_bf16_16x16x16_bf16
,
wmma_bf16_16x16x16_bf16
,
wmma_i32_16x16x16_iu8
,
wmma_i32_16x16x16_iu8
,
wmma_i32_16x16x16_iu4
wmma_i32_16x16x16_iu4
,
// gfx12
wmma_f32_16x16x16_f16_gfx12
,
wmma_f32_16x16x16_bf16_gfx12
,
wmma_i32_16x16x16_iu8_gfx12
,
};
};
/*
/*
...
@@ -279,6 +284,122 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
...
@@ -279,6 +284,122 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
}
}
};
};
// gfx12
// A-swizzled
template
<
index_t
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
,
WaveSize
,
typename
std
::
enable_if_t
<
WaveSize
==
32
||
WaveSize
==
64
>>
{
// Absolute fixing property
// * Data Pixel
static
constexpr
index_t
m_per_wmma
=
16
;
static
constexpr
index_t
n_per_wmma
=
16
;
static
constexpr
index_t
k_per_wmma
=
16
;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
// static constexpr index_t acc_data_size = 4;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static
constexpr
index_t
acc_data_size
=
4
;
static
constexpr
index_t
acc_pack_number
=
1
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
// Wave mode dependent propety
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
// static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
// static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
/
wave_size
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
static_assert
(
wave_size
==
32
,
"only support wave32 for gfx12 wmma"
);
if
constexpr
(
wave_size
==
32
)
{
intrin_wmma_f32_16x16x16_f16_w32_gfx12
<
MPerWmma
,
NPerWmma
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
template
<
index_t
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_f32_16x16x16_bf16_gfx12
,
WaveSize
,
typename
std
::
enable_if_t
<
WaveSize
==
32
||
WaveSize
==
64
>>
{
// Absolute fixing property
static
constexpr
index_t
m_per_wmma
=
16
;
static
constexpr
index_t
n_per_wmma
=
16
;
static
constexpr
index_t
k_per_wmma
=
16
;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
static
constexpr
index_t
acc_data_size
=
4
;
static
constexpr
index_t
acc_pack_number
=
1
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
// Wave mode dependent propety
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
// static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
// static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
/
wave_size
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
static_assert
(
wave_size
==
32
,
"only support wave32 for gfx12 wmma"
);
if
constexpr
(
wave_size
==
32
)
{
intrin_wmma_f32_16x16x16_bf16_w32_gfx12
<
MPerWmma
,
NPerWmma
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
template
<
index_t
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_i32_16x16x16_iu8_gfx12
,
WaveSize
,
typename
std
::
enable_if_t
<
WaveSize
==
32
||
WaveSize
==
64
>>
{
// Absolute fixing property
static
constexpr
index_t
m_per_wmma
=
16
;
static
constexpr
index_t
n_per_wmma
=
16
;
static
constexpr
index_t
k_per_wmma
=
16
;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
static
constexpr
index_t
acc_data_size
=
4
;
static
constexpr
index_t
acc_pack_number
=
1
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
// Wave mode dependent propety
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
// static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
// static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
/
wave_size
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
class
FloatA
,
class
FloatB
,
class
FloatC
,
bool
neg_a
=
false
,
bool
neg_b
=
false
,
bool
clamp
=
false
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
static_assert
(
wave_size
==
32
,
"only support wave32 for gfx12 wmma"
);
if
constexpr
(
wave_size
==
32
)
{
intrin_wmma_i32_16x16x16_iu8_w32_gfx12
<
MPerWmma
,
NPerWmma
,
neg_a
,
neg_b
,
clamp
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
template
<
typename
src_type_a
,
template
<
typename
src_type_a
,
typename
src_type_b
,
typename
src_type_b
,
typename
dst_type
,
typename
dst_type
,
...
@@ -296,13 +417,21 @@ struct WmmaSelector
...
@@ -296,13 +417,21 @@ struct WmmaSelector
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
{
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
;
#else
return
WmmaInstr
::
wmma_f32_16x16x16_f16
;
return
WmmaInstr
::
wmma_f32_16x16x16_f16
;
#endif
}
}
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
{
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_bf16_gfx12
;
#else
return
WmmaInstr
::
wmma_f32_16x16x16_bf16
;
return
WmmaInstr
::
wmma_f32_16x16x16_bf16
;
#endif
}
}
template
<
>
template
<
>
...
@@ -320,8 +449,13 @@ struct WmmaSelector
...
@@ -320,8 +449,13 @@ struct WmmaSelector
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
static
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
{
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_i32_16x16x16_iu8_gfx12
;
#else
return
WmmaInstr
::
wmma_i32_16x16x16_iu8
;
return
WmmaInstr
::
wmma_i32_16x16x16_iu8
;
#endif
}
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
static
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
...
@@ -502,6 +636,9 @@ struct WmmaGemm
...
@@ -502,6 +636,9 @@ struct WmmaGemm
__device__
static
auto
GetSubGroupId
()
__device__
static
auto
GetSubGroupId
()
{
{
static_assert
(
wmma_instr
.
num_thread_per_subgroups
*
wmma_instr
.
num_subgroups
==
wmma_instr
.
wave_size
,
""
);
return
(
GetLaneId
()
/
wmma_instr
.
num_thread_per_subgroups
)
%
wmma_instr
.
num_subgroups
;
return
(
GetLaneId
()
/
wmma_instr
.
num_thread_per_subgroups
)
%
wmma_instr
.
num_subgroups
;
}
}
...
@@ -516,12 +653,20 @@ struct WmmaGemm
...
@@ -516,12 +653,20 @@ struct WmmaGemm
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex
()
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
{
#ifdef __gfx12__
return
GetLaneIdUnderSubGroup
();
#else
return
TransposeC
?
GetLaneIdUnderSubGroup
()
:
GetSwizzledLaneIdLow
();
return
TransposeC
?
GetLaneIdUnderSubGroup
()
:
GetSwizzledLaneIdLow
();
#endif
}
}
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex
()
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
{
#ifdef __gfx12__
return
GetLaneIdUnderSubGroup
();
#else
return
TransposeC
?
GetSwizzledLaneIdLow
()
:
GetLaneIdUnderSubGroup
();
return
TransposeC
?
GetSwizzledLaneIdLow
()
:
GetLaneIdUnderSubGroup
();
#endif
}
}
__device__
static
CIndex
GetBeginOfThreadBlk
()
__device__
static
CIndex
GetBeginOfThreadBlk
()
...
...
include/ck/utility/amd_wmma.hpp
View file @
9fa379ea
...
@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
...
@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
}
}
};
};
// gfx12
/********************************WAVE32 MODE***********************************************/
#if defined(__gfx1200__)
#define __gfx12__
#endif
// src: fp16, dst: fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_wmma_f32_16x16x16_f16_w32_gfx12
;
template
<
>
struct
intrin_wmma_f32_16x16x16_f16_w32_gfx12
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half8_t
&
reg_a
,
const
half8_t
&
reg_b
,
FloatC
&
reg_c
)
{
// * Inline assembly need to elimate the duplicated data load, compiler won't help you
// delete them.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx12__)
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
#endif
}
};
// src: bf16, dst: fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_wmma_f32_16x16x16_bf16_w32_gfx12
;
template
<
>
struct
intrin_wmma_f32_16x16x16_bf16_w32_gfx12
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
bhalf8_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
FloatC
&
reg_c
)
{
#if defined(__gfx12__)
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
#endif
}
};
// src: iu8, dst: i32
template
<
index_t
MPerWave
,
index_t
NPerWave
,
bool
neg_a
,
bool
neg_b
,
bool
clamp
>
struct
intrin_wmma_i32_16x16x16_iu8_w32_gfx12
;
template
<
bool
neg_a
,
bool
neg_b
,
bool
clamp
>
struct
intrin_wmma_i32_16x16x16_iu8_w32_gfx12
<
16
,
16
,
neg_a
,
neg_b
,
clamp
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
int8x8_t
&
reg_a
,
const
int8x8_t
&
reg_b
,
FloatC
&
reg_c
)
{
#if defined(__gfx12__)
reg_c
.
template
AsType
<
int32x8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12
(
neg_a
,
bit_cast
<
int32x2_t
>
(
reg_a
),
neg_b
,
bit_cast
<
int32x2_t
>
(
reg_b
),
reg_c
.
template
AsType
<
int32x8_t
>()[
Number
<
0
>
{}],
clamp
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
#endif
}
};
}
// namespace ck
}
// namespace ck
#endif
#endif
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment