gaoqiong / composable_kernel_ROCM, commit 909f519c (unverified)

Merge branch 'develop' into universal_streamk

Authored Jun 27, 2024 by Harisankar Sadasivan; committed by GitHub on Jun 27, 2024.
Parents: 406fa265, 3bb0fe6c
Changes: 82 files in total; this page shows 20 changed files with 557 additions and 67 deletions (+557 / -67).
Changed files shown on this page:

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp  +2 -2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp  +1 -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp  +4 -3
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp  +3 -2
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp  +3 -2
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp  +14 -6
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp  +15 -7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp  +26 -18
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp  +27 -15
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp  +3 -2
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  +108 -1
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp  +146 -1
include/ck/utility/amd_smfmac.hpp  +69 -0
include/ck/utility/amd_wmma.hpp  +82 -0
include/ck/utility/data_type.hpp  +1 -1
include/ck/utility/synchronization.hpp  +17 -0
include/ck_tile/core/config.hpp  +4 -1
include/ck_tile/host/timer.hpp  +5 -5
include/ck_tile/ops/fmha.hpp  +10 -0
include/ck_tile/ops/fmha/block/block_masking.hpp  +17 -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp

@@ -107,7 +107,7 @@ __global__ void
             const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
-    defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
+    defined(__gfx11__) || defined(__gfx12__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -603,7 +603,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
         // check device
-        if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
-             ck::is_gfx11_supported()))
+        if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
+             ck::is_gfx11_supported() || ck::is_gfx12_supported()))
         {
             return false;
         }
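For context, a minimal host-side sketch (not part of the commit) of how an extended device check like the one above is typically consumed; the include path is an assumption, and only ck::get_device_name() and the is_gfx*_supported() helpers that appear in this hunk are used:

#include "ck/ck.hpp" // assumption: pulls in ck::get_device_name() and the is_gfx*_supported() helpers

// Returns true when the running GPU is one of the targets this DL conv kernel now accepts.
bool dl_conv_device_supported()
{
    return ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
           ck::is_gfx11_supported() || ck::is_gfx12_supported();
}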
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp

@@ -582,7 +582,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
         namespace ctc = tensor_layout::convolution;

         // check device
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp

@@ -39,8 +39,9 @@ __global__ void
             const BElementwiseOperation b_element_op,
             const CDEElementwiseOperation cde_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
+    defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \
+    defined(__gfx12__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
@@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
         }

-        if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
-           ck::is_gfx103_supported() || ck::is_gfx11_supported())
+        if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
+           ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
         {
             for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
             {
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp

@@ -61,7 +61,7 @@ __global__ void
             bool input_permute,
             bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // clang-format off
     // ***************************************************
@@ -166,6 +166,7 @@ __global__ void
     ignore = O;
     ignore = G0;
     ignore = G1;
+    ignore = alpha;
     ignore = input_permute;
     ignore = output_permute;
 #endif // end of if (defined(__gfx11__))
@@ -596,7 +597,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp

@@ -60,7 +60,7 @@ __global__ void
             bool input_permute,
             bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // clang-format off
     // ***************************************************
@@ -165,6 +165,7 @@ __global__ void
     ignore = O;
     ignore = G0;
     ignore = G1;
+    ignore = alpha;
     ignore = input_permute;
     ignore = output_permute;
 #endif // end of if (defined(__gfx11__))
@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
     static bool IsSupportedArgument(const RawArg& arg)
    {
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp

@@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         if constexpr(B0EnableLds)
         {
             // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
             constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
             constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_KRow = I2;
+#else
             constexpr auto B_KRow = I1;
+#endif

             return transform_tensor_descriptor(
                 B0BlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
                            make_pass_through_transform(Number<B_K1>{})),
@@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         if constexpr(B1EnableLds)
         {
             // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
             constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
             constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_LRow = I2;
+#else
             constexpr auto B_LRow = I1;
+#endif

             return transform_tensor_descriptor(
                 B1BlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_L0 / B_LRow>{}, B_LRow)),
                            make_unmerge_transform(make_tuple(
                                Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                            make_pass_through_transform(Number<B_L1>{})),
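A small compile-time sketch (assumed example values, not taken from the commit) of what the gfx12 branch above changes: the K0 length of the LDS block descriptor is unmerged as (K0 / 2, 2) instead of (K0, 1), so the same K0 elements are simply re-grouped across the two rows of a wave:

// Illustrative only: B_K0 = 8 is an assumed descriptor length.
constexpr int B_K0       = 8;
constexpr int krow_gfx12 = 2; // B_KRow = I2 on the gfx12 branch above

// gfx11 path unmerges K0 as (8, 1); gfx12 path unmerges it as (4, 2).
static_assert((B_K0 / krow_gfx12) * krow_gfx12 == B_K0,
              "the (K0 / KRow, KRow) split covers the same K0 elements");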
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp

@@ -50,7 +50,7 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma
         if constexpr(AEnableLds)
         {
             // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
             constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
             constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto A_KRow = I2;
+#else
             constexpr auto A_KRow = I1;
+#endif
             return transform_tensor_descriptor(
                 ABlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                            make_pass_through_transform(Number<A_K1>{})),
@@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma
         if constexpr(BEnableLds)
         {
             // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
             constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
             constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_KRow = I2;
+#else
             constexpr auto B_KRow = I1;
+#endif
             return transform_tensor_descriptor(
                 BBlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                            make_pass_through_transform(Number<B_K1>{})),
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp

@@ -54,7 +54,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map,
             const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -147,7 +147,7 @@ __global__ void
             const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
             const Block2CTileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // printf("entry kernel launch");
     __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
@@ -237,7 +237,7 @@ __global__ void
             const CDEElementwiseOperation cde_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];

     GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma
         }
         else
         {
+            constexpr auto A_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / A_KRow / K1;
             // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},
@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma
         }
         else
         {
+            constexpr auto B_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / B_KRow / K1;
             // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},
@@ -495,12 +497,16 @@ struct GridwiseGemmMultipleD_Wmma
         if constexpr(AEnableLds)
         {
             // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
             constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
             constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto A_KRow = I2;
+#else
             constexpr auto A_KRow = I1;
+#endif

             return transform_tensor_descriptor(
                 ABlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                            make_pass_through_transform(Number<A_K1>{})),
@@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma
         if constexpr(BEnableLds)
         {
             // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
             constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
             constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_KRow = I2;
+#else
             constexpr auto B_KRow = I1;
+#endif

             return transform_tensor_descriptor(
                 BBlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                            make_pass_through_transform(Number<B_K1>{})),
@@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma
     // *Caution Here repeat is shuffle repeat
     GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
     {
-        constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
-        constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
-
         constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
             make_naive_tensor_descriptor_packed(
                 make_tuple(I1,
-                           Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
+                           Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
                            I1,
-                           Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
+                           Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));

         return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
     }
@@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma
         const auto M = e_grid_desc_m_n.GetLength(I0);
         const auto N = e_grid_desc_m_n.GetLength(I1);

         const auto MBlock = M / MPerBlock;
         const auto NBlock = N / NPerBlock;

         const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
             e_grid_desc_m_n,
             make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

@@ -45,7 +45,7 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma
         }
         else
         {
+            constexpr auto A_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / A_KRow / K1;
             // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},
@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma
         }
         else
         {
+            constexpr auto B_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / B_KRow / K1;
             // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},
@@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma
         if constexpr(AEnableLds)
         {
             // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
             constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
             constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto A_KRow = I2;
+#else
             constexpr auto A_KRow = I1;
+#endif

             return transform_tensor_descriptor(
                 ABlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                            make_pass_through_transform(Number<A_K1>{})),
@@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma
         if constexpr(BEnableLds)
         {
             // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
             constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
             constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_KRow = I2;
+#else
             constexpr auto B_KRow = I1;
+#endif

             return transform_tensor_descriptor(
                 BBlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                            make_pass_through_transform(Number<B_K1>{})),
@@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma
             c_grid_desc_m_n);
     }

-    using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
-
-    using DefaultBlock2CTileMap =
-        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
-
     struct SharedMemTrait
     {
         // LDS allocation for A and B: be careful of alignment
@@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma
             b_block_space_size_aligned * sizeof(BDataType));
     };

+    using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
+
+    using DefaultBlock2CTileMap =
+        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
+
     template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void Run(const ADataType* __restrict__ p_a_grid,
                                const BDataType* __restrict__ p_b_grid,
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp

@@ -35,8 +35,9 @@ __global__ void
             const Block2ETileMap block_2_tile_map,
             const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
+    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
+    defined(__gfx12__))
     GridwiseTensorRearrangeKernel::Run(in_grid_desc,
                                        p_in_global,
                                        out_grid_desc,
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
     ElementwiseOperation element_op_;
 };

-// Specilized for WMMA
+// Specilized for WMMA-Navi3
 // A single Wave32 is composed by double row
 // Data exchange allowed between these two rows
 // This RowLane Dst buf will be filled from two Src buf
@@ -1439,4 +1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
     ElementwiseOperation element_op_{};
 };

+// Specilized for WMMA-Navi4
+template <typename SrcData,
+          typename DstData,
+          typename SrcDesc,
+          typename DstDesc,
+          typename ElementwiseOperation,
+          typename SliceLengths,
+          typename DimAccessOrder,
+          index_t DstVectorDim,
+          index_t DstScalarPerVector,
+          bool IntraRowSwizzlePerm,
+          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                             bool>::type = false>
+struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
+{
+    static constexpr index_t nDim = SliceLengths::Size();
+
+    using Index = MultiIndex<nDim>;
+
+    __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx)
+    {
+        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc need to known at compile-time");
+
+        static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
+                      "wrong! Not divisible");
+        ignore = src_idx;
+    }
+
+    template <typename SrcSliceOriginIdx,
+              typename DstSliceOriginIdx,
+              typename SrcBuffer,
+              typename DstBuffer>
+    __device__ void Run(const SrcDesc&,
+                        const SrcSliceOriginIdx&,
+                        const SrcBuffer& src_buf,
+                        const DstDesc&,
+                        const DstSliceOriginIdx&,
+                        DstBuffer& dst_buf) const
+    {
+        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc need to known at compile-time");
+
+        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
+                      "wrong! SliceOrigin need to known at compile-time");
+
+        static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
+                      "wrong! Buffer need to be StaticBuffer");
+
+        // SrcDesc and src_slice_origin_idx are known at compile-time
+        constexpr auto src_desc             = remove_cvref_t<SrcDesc>{};
+        constexpr auto dst_desc             = remove_cvref_t<DstDesc>{};
+        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
+        constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
+
+        // scalar per access on each dim
+        constexpr auto dst_scalar_per_access = generate_sequence(
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+
+        constexpr auto dst_scalar_step_in_vector =
+            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
+
+        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
+                                                    DimAccessOrder,
+                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;
+
+        static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
+                      "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
+
+        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
+
+        static_for<0, num_access, 1>{}([&](auto idx_1d) {
+            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
+
+            // copy data from src_buf into dst_vector
+            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
+                // src_desc error, non constexpr, caused by merge transform
+                constexpr index_t src_offset = src_desc.CalculateOffset(
+                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                    dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                SrcData v_this_row;
+                // int type temp value due to intrinsic requirement
+                int temp = 0;
+
+                // apply element-wise operation
+                element_op_(v_this_row, src_buf[Number<src_offset>{}]);
+
+                // apply intra-row permute.
+                if constexpr(IntraRowSwizzlePerm)
+                {
+                    temp = __builtin_amdgcn_permlane16(
+                        temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
+                    v_this_row = type_convert_sp<SrcData>(temp);
+                }
+
+                // apply type convert
+                dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
+            });
+        });
+    }
+    ElementwiseOperation element_op_{};
+};
+
 } // namespace ck
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp

@@ -11,12 +11,17 @@ namespace ck {
 enum struct WmmaInstr
 {
+    // gfx11
     wmma_f32_16x16x16_f16 = 0,
     wmma_f32_16x16x16_bf16,
     wmma_f16_16x16x16_f16,
     wmma_bf16_16x16x16_bf16,
     wmma_i32_16x16x16_iu8,
-    wmma_i32_16x16x16_iu4
+    wmma_i32_16x16x16_iu4,
+    // gfx12
+    wmma_f32_16x16x16_f16_gfx12,
+    wmma_f32_16x16x16_bf16_gfx12,
+    wmma_i32_16x16x16_iu8_gfx12,
 };

 /*
@@ -279,6 +284,122 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
     }
 };

+// gfx12
+// A-swizzled
+template <index_t WaveSize>
+struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
+                 WaveSize,
+                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
+{
+    // Absolute fixing property
+    // * Data Pixel
+    static constexpr index_t m_per_wmma = 16;
+    static constexpr index_t n_per_wmma = 16;
+    static constexpr index_t k_per_wmma = 16;
+    // static constexpr index_t src_a_data_size = 2;
+    // static constexpr index_t src_b_data_size = 2;
+    // static constexpr index_t acc_data_size = 4;
+    // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
+    static constexpr index_t acc_data_size            = 4;
+    static constexpr index_t acc_pack_number          = 1;
+    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
+
+    // Wave mode dependent propety
+    static constexpr index_t wave_size = Number<WaveSize>{};
+    // * Fixed in Navi3x, Will be wave mode dependent on Navi4x
+    // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
+    // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
+    // * num_acc_vgprs_per_wave alone M direction
+    // * num_subgroups alone M direction
+    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
+    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;
+
+    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
+        if constexpr(wave_size == 32)
+        {
+            intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
+        }
+    }
+};
+
+template <index_t WaveSize>
+struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16_gfx12,
+                 WaveSize,
+                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
+{
+    // Absolute fixing property
+    static constexpr index_t m_per_wmma = 16;
+    static constexpr index_t n_per_wmma = 16;
+    static constexpr index_t k_per_wmma = 16;
+    // static constexpr index_t src_a_data_size = 2;
+    // static constexpr index_t src_b_data_size = 2;
+    static constexpr index_t acc_data_size            = 4;
+    static constexpr index_t acc_pack_number          = 1;
+    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
+
+    // Wave mode dependent propety
+    static constexpr index_t wave_size = Number<WaveSize>{};
+    // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
+    // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
+    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
+    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;
+
+    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
+        if constexpr(wave_size == 32)
+        {
+            intrin_wmma_f32_16x16x16_bf16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
+        }
+    }
+};
+
+template <index_t WaveSize>
+struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
+                 WaveSize,
+                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
+{
+    // Absolute fixing property
+    static constexpr index_t m_per_wmma = 16;
+    static constexpr index_t n_per_wmma = 16;
+    static constexpr index_t k_per_wmma = 16;
+    // static constexpr index_t src_a_data_size = 2;
+    // static constexpr index_t src_b_data_size = 2;
+    static constexpr index_t acc_data_size            = 4;
+    static constexpr index_t acc_pack_number          = 1;
+    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
+
+    // Wave mode dependent propety
+    static constexpr index_t wave_size = Number<WaveSize>{};
+    // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
+    // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
+    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
+    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;
+
+    template <index_t MPerWmma,
+              index_t NPerWmma,
+              class FloatA,
+              class FloatB,
+              class FloatC,
+              bool neg_a = false,
+              bool neg_b = false,
+              bool clamp = false>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
+        if constexpr(wave_size == 32)
+        {
+            intrin_wmma_i32_16x16x16_iu8_w32_gfx12<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
+                a, b, reg_c);
+        }
+    }
+};
+
 template <typename src_type_a,
           typename src_type_b,
           typename dst_type,
@@ -296,13 +417,21 @@ struct WmmaSelector
     template <>
     static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
     {
+#ifdef __gfx12__
+        return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
+#else
         return WmmaInstr::wmma_f32_16x16x16_f16;
+#endif
     }

     template <>
     static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
     {
+#ifdef __gfx12__
+        return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
+#else
         return WmmaInstr::wmma_f32_16x16x16_bf16;
+#endif
     }

     template <>
@@ -320,8 +449,13 @@ struct WmmaSelector
     template <>
     static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
     {
+#ifdef __gfx12__
+        return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
+#else
         return WmmaInstr::wmma_i32_16x16x16_iu8;
+#endif
     }

 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
     template <>
     static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
@@ -502,6 +636,9 @@ struct WmmaGemm
     __device__ static auto GetSubGroupId()
     {
+        static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups ==
+                          wmma_instr.wave_size,
+                      "");
         return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
     }
@@ -516,12 +653,20 @@ struct WmmaGemm
     __host__ __device__ static auto CalculateAThreadOriginDataIndex()
     {
+#ifdef __gfx12__
+        return GetLaneIdUnderSubGroup();
+#else
         return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
+#endif
     }

     __host__ __device__ static auto CalculateBThreadOriginDataIndex()
     {
+#ifdef __gfx12__
+        return GetLaneIdUnderSubGroup();
+#else
         return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
+#endif
     }

     __device__ static CIndex GetBeginOfThreadBlk()
include/ck/utility/amd_smfmac.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"

#pragma once

namespace ck {

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32f16;

template <>
struct intrin_smfmac_f32_16x16x32f16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a,
                               const half8_t& reg_b,
                               const int32_t& reg_idx,
                               FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32bf16;

template <>
struct intrin_smfmac_f32_16x16x32bf16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf4_t& reg_a,
                               const bhalf8_t& reg_b,
                               const int32_t& reg_idx,
                               FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_32x32x16f16;

template <>
struct intrin_smfmac_f32_32x32x16f16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a,
                               const half8_t& reg_b,
                               const int32_t& reg_idx,
                               FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_32x32x16bf16;

template <>
struct intrin_smfmac_f32_32x32x16bf16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const bhalf4_t& reg_a,
                               const bhalf8_t& reg_b,
                               const int32_t& reg_idx,
                               FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
                reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

} // namespace ck
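A hypothetical device-side sketch (not part of the commit) of calling one of the new sparse-MFMA wrappers; the ck::vector_type accumulator and the meaning of reg_idx (packed structured-sparsity index data consumed by the smfmac instruction) are assumptions, since the new header does not spell them out:

#include "ck/utility/amd_smfmac.hpp"
#include "ck/utility/data_type.hpp" // assumption: provides ck::vector_type, half4_t, half8_t

// One 16x16x32 sparse multiply-accumulate step per wave: reg_a holds the compressed A
// fragment, reg_b the dense B fragment, reg_idx the packed sparsity indices.
__device__ void smfmac_step(const ck::half4_t& reg_a,
                            const ck::half8_t& reg_b,
                            const int32_t reg_idx,
                            ck::vector_type<float, 4>& acc)
{
    ck::intrin_smfmac_f32_16x16x32f16<16, 16>::Run(reg_a, reg_b, reg_idx, acc);
}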
include/ck/utility/amd_wmma.hpp

@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
     }
 };

+// gfx12
+/********************************WAVE32 MODE***********************************************/
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define __gfx12__
+#endif
+
+// src: fp16, dst: fp32
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;
+
+template <>
+struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
+    {
+        // * Inline assembly need to elimate the duplicated data load, compiler won't help you
+        // delete them.
+        // amd_assembly_wmma_f32_16x16x16_f16_w32(
+        //     reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
+#if defined(__gfx12__)
+        reg_c.template AsType<float8_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
+                reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+// src: bf16, dst: fp32
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12;
+
+template <>
+struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx12__)
+        reg_c.template AsType<float8_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
+                reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+// src: iu8, dst: i32
+template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
+struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12;
+
+template <bool neg_a, bool neg_b, bool clamp>
+struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx12__)
+        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+                neg_a,
+                bit_cast<int32x2_t>(reg_a),
+                neg_b,
+                bit_cast<int32x2_t>(reg_b),
+                reg_c.template AsType<int32x8_t>()[Number<0>{}],
+                clamp);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
 } // namespace ck
 #endif
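A minimal device-side sketch (not part of the commit) of calling the new wave32 f16 wrapper; the ck::vector_type<float, 8> accumulator is an assumption, and on non-gfx12 builds the wrapper compiles to the ignore-everything no-op shown above:

__device__ void wmma_f16_16x16x16_step(const ck::half8_t& a_frag,
                                       const ck::half8_t& b_frag,
                                       ck::vector_type<float, 8>& c_frag)
{
    // Accumulates one 16x16x16 tile per wave32: c_frag += a_frag * b_frag.
    ck::intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16>::Run(a_frag, b_frag, c_frag);
}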
include/ck/utility/data_type.hpp

@@ -203,7 +203,7 @@ struct vector_type<T, 1>
     }
 };

-int static err = 0;
+__device__ int static err = 0;

 template <typename T>
 struct vector_type<T, 2>
 {
include/ck/utility/synchronization.hpp

@@ -10,12 +10,20 @@ namespace ck {
 __device__ void block_sync_lds()
 {
 #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
+#ifdef __gfx12__
+    asm volatile("\
+    s_wait_dscnt 0x0 \n \
+    s_barrier_signal -1 \n \
+    s_barrier_wait -1 \
+    " ::);
+#else
     // asm volatile("\
     // s_waitcnt lgkmcnt(0) \n \
     // s_barrier \
     // " ::);
     __builtin_amdgcn_s_waitcnt(0xc07f);
     __builtin_amdgcn_s_barrier();
+#endif
 #else
     __syncthreads();
 #endif
@@ -23,11 +31,20 @@ __device__ void block_sync_lds()
 __device__ void block_sync_lds_direct_load()
 {
+#ifdef __gfx12__
+    asm volatile("\
+    s_wait_vmcnt 0x0 \n \
+    s_wait_dscnt 0x0 \n \
+    s_barrier_signal -1 \n \
+    s_barrier_wait -1 \
+    " ::);
+#else
     asm volatile("\
     s_waitcnt vmcnt(0) \n \
     s_waitcnt lgkmcnt(0) \n \
     s_barrier \
     " ::);
+#endif
 }

 __device__ void s_nop()
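A small usage sketch (illustrative, not from the commit) of where block_sync_lds() sits in a kernel; on gfx12 the inline assembly above issues the split s_barrier_signal / s_barrier_wait pair instead of a single s_barrier:

__device__ void lds_exchange(float* p_lds, float v, float& out)
{
    p_lds[threadIdx.x] = v;       // each thread writes its value to LDS
    ck::block_sync_lds();         // wait for LDS traffic, then barrier the block
    out = p_lds[threadIdx.x ^ 1]; // safely read the neighbouring thread's value
}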
include/ck_tile/core/config.hpp

@@ -17,6 +17,9 @@
 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
 #define __gfx11__
 #endif
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define __gfx12__
+#endif

 #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
 #include "hip/hip_runtime.h"
@@ -155,7 +158,7 @@
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(__gfx103__) // for GPU code
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
-#elif defined(__gfx11__) // for GPU code
+#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
 #endif
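With the new macro in place, ck_tile device code can branch per target family at compile time; a minimal sketch (the branch bodies are hypothetical placeholders):

// Assumes ck_tile/core/config.hpp has been included so __gfx11__ / __gfx12__ are
// defined when compiling for the corresponding targets.
__device__ inline void arch_specific_path()
{
#if defined(__gfx12__)
    // gfx12-only path (e.g. the new WMMA layout or split barriers)
#elif defined(__gfx11__)
    // gfx11 path
#else
    // generic fallback
#endif
}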
include/ck_tile/host/timer.hpp

@@ -27,7 +27,7 @@ struct gpu_timer
     CK_TILE_HOST void start(const hipStream_t& s)
     {
-        HIP_CHECK_ERROR(hipDeviceSynchronize());
+        HIP_CHECK_ERROR(hipStreamSynchronize(s));
         HIP_CHECK_ERROR(hipEventRecord(start_evt, s));
     }
@@ -51,15 +51,15 @@ struct gpu_timer
 struct cpu_timer
 {
     // torch.utils.benchmark.Timer(), there is a sync inside each timer callback
-    CK_TILE_HOST void start(const hipStream_t&)
+    CK_TILE_HOST void start(const hipStream_t& s)
     {
-        HIP_CHECK_ERROR(hipDeviceSynchronize());
+        HIP_CHECK_ERROR(hipStreamSynchronize(s));
         start_tick = std::chrono::high_resolution_clock::now();
     }
     // torch.utils.benchmark.Timer(), there is a sync inside each timer callback
-    CK_TILE_HOST void stop(const hipStream_t&)
+    CK_TILE_HOST void stop(const hipStream_t& s)
     {
-        HIP_CHECK_ERROR(hipDeviceSynchronize());
+        HIP_CHECK_ERROR(hipStreamSynchronize(s));
         stop_tick = std::chrono::high_resolution_clock::now();
     }
     // return in ms
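A host-side usage sketch of cpu_timer after this change (the elapsed-time accessor is not shown in the hunk, so it is omitted here); the practical effect is that each measurement now synchronizes only the stream it is timing rather than the whole device:

#include <hip/hip_runtime.h>
#include "ck_tile/host/timer.hpp"

void time_on_stream(hipStream_t stream)
{
    ck_tile::cpu_timer timer{};
    timer.start(stream); // now hipStreamSynchronize(stream) rather than hipDeviceSynchronize()
    // ... enqueue and run work on `stream` ...
    timer.stop(stream);  // likewise syncs only this stream before taking the CPU tick
}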
include/ck_tile/ops/fmha.hpp

@@ -10,6 +10,10 @@
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
@@ -22,6 +26,12 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
include/ck_tile/ops/fmha/block/block_masking.hpp

@@ -299,6 +299,23 @@ struct SimplifiedGenericAttentionMask
         }
     }

+    template <index_t TileHeight, index_t TileWidth>
+    CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
+                                                          number<TileHeight> height,
+                                                          number<TileWidth> width,
+                                                          index_t num_splits,
+                                                          index_t i_split) const
+    {
+        auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
+
+        const index_t x_per_split = ck_tile::max(1, x_total / num_splits);
+        const index_t split_start = x_per_split * i_split;
+        const index_t split_end =
+            (i_split == num_splits - 1 ? x_total : split_start + x_per_split);
+
+        return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
+                                   ck_tile::min(origin_end, split_end));
+    }
+
     // to get the loop length along Y axis, return index:[start, end), end-start=length
     // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
     // TODO: y_end still could be negative, so end-start could be negative(need check)
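A plain host-side sketch (not from the commit) of the split arithmetic added above, with assumed example values: the unsplit mask range along X is clipped to the sub-range owned by one split, and the last split absorbs any remainder of x_total.

#include <algorithm>
#include <cassert>
#include <utility>

// Mirrors the arithmetic of the new GetTileRangeAlongX overload for illustration.
std::pair<int, int>
split_range(std::pair<int, int> origin, int x_total, int num_splits, int i_split)
{
    const int x_per_split = std::max(1, x_total / num_splits);
    const int split_start = x_per_split * i_split;
    const int split_end   = (i_split == num_splits - 1) ? x_total : split_start + x_per_split;
    return {std::max(origin.first, split_start), std::min(origin.second, split_end)};
}

int main()
{
    // Unsplit mask range [0, 96) with x_total = 96 and 4 splits: split 2 covers [48, 72).
    assert((split_range({0, 96}, 96, 4, 2) == std::pair<int, int>{48, 72}));
    return 0;
}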