gaoqiong / composable_kernel_ROCM · Commits · ce30621d

Unverified commit ce30621d, authored Jun 17, 2024 by Rostyslav Geyyer; committed by GitHub on Jun 17, 2024.

Merge branch 'develop' into lwpck-1815

Parents: e8450b71, 17ed368f
Changes: 22 · Showing 20 changed files with 405 additions and 231 deletions (+405, -231)

  include/ck/ck.hpp (+1, -1)
  include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp (+3, -2)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp (+4, -7)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp (+3, -6)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp (+3, -6)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp (+7, -13)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp (+3, -6)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp (+8, -10)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp (+5, -4)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (+109, -54)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp (+76, -49)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp (+5, -4)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp (+5, -4)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp (+30, -30)
  include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp (+3, -2)
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp (+7, -13)
  include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp (+81, -14)
  include/ck_tile/core/arch/arch.hpp (+7, -4)
  include/ck_tile/core/config.hpp (+4, -0)
  include/ck_tile/core/tensor/tile_elementwise.hpp (+41, -2)
include/ck/ck.hpp

@@ -155,7 +155,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 #define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1

 // LDS direct loads using inline assembly
-#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
+#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0

 // set stochastic rounding as default for f8 conversions
 #define CK_USE_SR_F8_CONVERSION 1
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -247,7 +247,8 @@ struct DeviceColumnToImageImpl
                                             independent_filter_strides,
                                             conv_filter_dilations,
                                             input_left_pads_with_offset,
-                                            input_right_pads);
+                                            input_right_pads,
+                                            N);
        const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -93,12 +93,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
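Note: the same two-line edit repeats across most of the device and gridwise headers below — the per-group offsets are now read directly from compute_ptr_offset_of_batch, and the __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(...)) wrapper is dropped. A plausible reading (an assumption, not stated in the commit) is that the 32-bit form of that broadcast intrinsic would truncate a 64-bit offset. The host-side sketch below only illustrates that truncation with stand-in typedefs; it is not code from this repository.

    // Host-side sketch (hypothetical types, not from the commit): what a 32-bit
    // round-trip does to an element offset larger than 2^31.
    #include <cstdint>
    #include <iostream>

    using index_t      = std::int32_t;
    using long_index_t = std::int64_t;

    // Stand-in for a 32-bit uniform-broadcast intrinsic: only 32 bits survive.
    static index_t broadcast32(index_t v) { return v; }

    int main()
    {
        const long_index_t batch_stride = 3'000'000'000LL; // elements per group
        const long_index_t offset       = 1 * batch_stride;

        const long_index_t round_tripped = broadcast32(static_cast<index_t>(offset));
        std::cout << "true offset:        " << offset << '\n'
                  << "after 32-bit trip:  " << round_tripped << '\n'; // wraps negative
    }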
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp

@@ -54,12 +54,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);

     __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp

@@ -66,12 +66,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);

     __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp

@@ -59,12 +59,9 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);

     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -116,12 +113,9 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);

     // Pass two lds pointer is the key to tell compiler that ds_read/write
     // operate on different lds chunk at same time without order dependecy

@@ -1268,7 +1262,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                 arg.Conv_G_;
             std::array<index_t, I1> in_out_batch_strides = {
-                arg.compute_ptr_offset_of_batch_.BatchStrideC_};
+                static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};

             const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
                                                            ck::Tuple<CElementwiseGridDesc_M_N>,
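Note: the one non-mechanical edit in this file is the added static_cast<index_t> where BatchStrideC_ (now a long_index_t, see device_grouped_conv_utils.hpp below) initializes a std::array<index_t, I1>. Brace initialization rejects narrowing conversions, so the explicit cast is required. A minimal stand-alone illustration with plain fixed-width types (not the CK aliases):

    #include <array>
    #include <cstdint>

    int main()
    {
        const std::int64_t batch_stride_c = 1024; // wide stride, as in the updated struct

        // std::array<std::int32_t, 1> strides = {batch_stride_c};          // ill-formed: narrowing
        std::array<std::int32_t, 1> strides = {
            static_cast<std::int32_t>(batch_stride_c)}; // explicit cast, as in the diff

        return strides[0] == 1024 ? 0 : 1;
    }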
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp

@@ -61,12 +61,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);

     __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -97,12 +97,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

@@ -266,7 +263,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
                                                                         conv_filter_strides,
                                                                         conv_filter_dilations,
                                                                         input_left_pads,
-                                                                        input_right_pads);
+                                                                        input_right_pads,
+                                                                        a_g_n_c_wis_lengths[I1]);
         const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

@@ -312,8 +310,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
         const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
     {
         const auto out_gemmmraw_gemmnraw_desc =
-            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                        e_g_n_k_wos_strides);
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+                e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
         const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -263,7 +263,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
                                                                         conv_filter_strides,
                                                                         conv_filter_dilations,
                                                                         input_left_pads,
-                                                                        input_right_pads);
+                                                                        input_right_pads,
+                                                                        a_g_n_c_wis_lengths[I1]);
         const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

@@ -310,8 +311,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
         const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
     {
         const auto out_gemmmraw_gemmnraw_desc =
-            conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(c_g_n_k_wos_lengths,
-                                                                        c_g_n_k_wos_strides);
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(
+                c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]);
         const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

@@ -69,7 +69,8 @@ template <typename GridwiseGemm,
           typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
           typename Block2ETileMap,
-          typename ComputePtrOffsetOfBatch,
+          typename ComputePtrOffsetOfG,
+          typename ComputePtrOffsetOfN,
           bool HasMainKBlockLoop,
           bool isMultiA,
           bool isMultiB>

@@ -85,7 +86,7 @@ __global__ void
        const AElementwiseOperation a_element_op,
        const BElementwiseOperation b_element_op,
        const CDEElementwiseOperation cde_element_op,
-       const index_t batch_count,
+       const index_t groups_count,
        const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
        const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock

@@ -93,18 +94,22 @@ __global__ void
            e_grid_desc_mblock_mperblock_nblock_nperblock_,
        const Block2ETileMap block_2_ctile_map,
-       const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+       const ComputePtrOffsetOfG compute_ptr_offset_of_groups,
+       const ComputePtrOffsetOfN compute_ptr_offset_of_n)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
     // offset base pointer for each work-group
-    const index_t num_blocks_per_batch =
-        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
-    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
-    const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
+    const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
+    const index_t& num_blocks_per_n    = groups_count;
+    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
+    const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
+    const auto& ds_batch_offset       = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
+    const long_index_t e_n_offset     = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);

     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -121,13 +126,28 @@ __global__ void
         AsPointer p_as_grid_grp;
         BsPointer p_bs_grid_grp;

-        const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
-        static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
-        static_for<0, NumATensor, 1>{}(
-            [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
+        const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
+        // compute_ptr_offset_of_n_ not need BatchStrideB so
+        // in case of MultiA is false but isMultiB is true
+        // BatchStrideA_ is not tuple.
+        if constexpr(isMultiA)
+        {
+            const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx);
+            static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
+            static_for<0, NumATensor, 1>{}([&](auto i) {
+                p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i];
+            });
+        }
+        else
+        {
+            const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
+            static_for<0, 1, 1>{}(
+                [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; });
+        }

-        const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
+        const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
         static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
         static_for<0, NumBTensor, 1>{}(

@@ -137,7 +157,7 @@ __global__ void
            p_as_grid_grp,
            p_bs_grid_grp,
            p_ds_grid_grp,
-           p_e_grid + e_batch_offset,
+           p_e_grid + e_batch_offset + e_n_offset,
            p_shared,
            a_element_op,
            b_element_op,

@@ -150,16 +170,16 @@ __global__ void
     }
     else
     {
-        const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-        const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+        const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
+        const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
+        const long_index_t a_n_offset     = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
         GridwiseGemm::template Run<HasMainKBlockLoop>(
-            p_as_grid + a_batch_offset,
+            p_as_grid + a_batch_offset + a_n_offset,
             p_bs_grid + b_batch_offset,
             p_ds_grid_grp,
-            p_e_grid + e_batch_offset,
+            p_e_grid + e_batch_offset + e_n_offset,
             p_shared,
             a_element_op,
             b_element_op,

@@ -175,7 +195,7 @@ __global__ void
     ignore = p_bs_grid;
     ignore = p_ds_grid;
     ignore = p_e_grid;
-    ignore = batch_count;
+    ignore = groups_count;
     ignore = a_grid_desc_k0_m_k1;
     ignore = b_grid_desc_k0_n_k1;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;

@@ -183,7 +203,8 @@ __global__ void
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = cde_element_op;
-    ignore = compute_ptr_offset_of_batch;
+    ignore = compute_ptr_offset_of_groups;
+    ignore = compute_ptr_offset_of_n;
     ignore = block_2_ctile_map;
 #endif
 }
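Note: after this hunk the E pointer is advanced by two offsets instead of one — a group offset from compute_ptr_offset_of_groups and an N-slice offset from compute_ptr_offset_of_n. Further down, in this file's Argument constructor, their strides are set to e_g_n_k_wos_strides[0] and e_g_n_k_wos_strides[1] * conv_N_per_block_ respectively. The sketch below (hypothetical stride values, plain C++) only shows how the two contributions add up to one flat element offset.

    #include <cstdint>
    #include <iostream>

    int main()
    {
        using long_index_t = std::int64_t;

        // Hypothetical output strides: stride_g = per group, stride_n = per image (per N).
        const long_index_t stride_g = 1'000'000;
        const long_index_t stride_n = 50'000;
        const int conv_N_per_block  = 4; // images handled per N-slice

        const int g_idx = 2; // group index for this workgroup
        const int n_idx = 3; // N-slice index for this workgroup

        const long_index_t e_batch_offset = static_cast<long_index_t>(g_idx) * stride_g;
        const long_index_t e_n_offset =
            static_cast<long_index_t>(n_idx) * stride_n * conv_N_per_block;

        // p_e_grid + e_batch_offset + e_n_offset in the kernel corresponds to this sum:
        std::cout << e_batch_offset + e_n_offset << '\n'; // 2600000
    }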
@@ -309,7 +330,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
-       const std::array<index_t, NDimSpatial>& input_right_pads)
+       const std::array<index_t, NDimSpatial>& input_right_pads,
+       const index_t Conv_N)
    {
        const auto in_gemmmraw_gemmkraw_desc =
            conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,

@@ -321,7 +343,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                                                                        conv_filter_strides,
                                                                        conv_filter_dilations,
                                                                        input_left_pads,
-                                                                       input_right_pads);
+                                                                       input_right_pads,
+                                                                       Conv_N);
        const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

@@ -347,11 +370,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    template <typename ELay>
    static auto MakeEGridDescriptor_M_N(
        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-       const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
+       const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
+       const index_t Conv_N)
    {
        const auto out_gemmmraw_gemmnraw_desc =
-           conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                       e_g_n_k_wos_strides);
+           conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+               e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
        const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);

@@ -363,24 +387,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    // Pass e_g_n_k_wos_lengths for logical broadcast.
    static auto MakeDsGridDescriptor_M_N(
        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-       const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
+       const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
+       const index_t Conv_N)
    {
        return generate_tuple(
            [&](auto i) {
                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-               return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
-                                                                 ds_g_n_k_wos_strides[i]);
+               return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
+                   e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], Conv_N);
            },
            Number<NumDTensor>{});
    }

    // desc for problem definition
    using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
-       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
    using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
-   using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
-   using EGridDesc_M_N  = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
+   using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, 1))>;
+   using EGridDesc_M_N  = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;

    // If we are using multiAB and one of the template datatype parameters is not a tuple, convert
    // it to it
@@ -468,6 +493,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e)},
              num_group_{a_g_n_c_wis_lengths[0]},
+             conv_N_per_block_{
+                 conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
+                     a_g_n_c_wis_lengths,
+                     a_g_n_c_wis_strides,
+                     e_g_n_k_wos_lengths,
+                     e_g_n_k_wos_strides)},
              a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
                                                                          a_g_n_c_wis_strides,
                                                                          b_g_k_c_xs_lengths,

@@ -477,12 +508,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                                                                          conv_filter_strides,
                                                                          conv_filter_dilations,
                                                                          input_left_pads,
-                                                                         input_right_pads)},
+                                                                         input_right_pads,
+                                                                         conv_N_per_block_)},
              b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
                                                                          b_g_k_c_xs_strides)},
              ds_grid_desc_m_n_{},
-             e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
-                                                                         e_g_n_k_wos_strides)},
+             e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
+                 e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
              a_grid_desc_ak0_m_ak1_{
                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
              b_grid_desc_bk0_n_bk1_{

@@ -490,7 +522,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
              ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
-             compute_ptr_offset_of_batch_{},
+             compute_ptr_offset_of_groups_{},
+             compute_ptr_offset_of_n_{},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op},

@@ -511,8 +544,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            if constexpr(isMultiA || isMultiB)
            {
                static_for<0, NumATensor, 1>{}([&](auto i) {
-                   // Init compute_ptr_offset_of_batch_ for multiple AB
-                   compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
+                   // Init compute_ptr_offset_of_groups_ for multiple AB
+                   compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];

                    // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
                    // type is not tuple)

@@ -524,16 +557,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                    {
                        // p_as is tuple
                        p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
+                       // compute_ptr_offset_of_n_ not need BatchStrideB so
+                       // in case of MultiA is false but isMultiB is true
+                       // BatchStrideA_ is not tuple.
+                       compute_ptr_offset_of_n_.BatchStrideA_(i) =
+                           a_g_n_c_wis_strides[1] * conv_N_per_block_;
                    }
                    else
                    {
                        // if MultiB and not MultiA then p_as is single pointer
                        p_as_grid_(i) = static_cast<const DataType*>(p_as);
+                       compute_ptr_offset_of_n_.BatchStrideA_ =
+                           a_g_n_c_wis_strides[1] * conv_N_per_block_;
                    }
                });
                static_for<0, NumBTensor, 1>{}([&](auto i) {
-                   // Init compute_ptr_offset_of_batch_ for multiple AB
-                   compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
+                   // Init compute_ptr_offset_of_groups_ for multiple AB
+                   compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];

                    using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
                    // It is possible that one of the AB is a pointer and one is a tuple.

@@ -553,8 +593,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            }
            else
            {
-               compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-               compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+               compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
+               compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+               compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;

                // p_as and p_bs are pointers
                p_as_grid_(I0) = static_cast<const ADataType*>(p_as);

@@ -570,13 +611,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);

                // D batch stride
-               compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
+               compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
+               compute_ptr_offset_of_n_.BatchStrideDs_(i) =
+                   ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;

                // D desc
                ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
-                   e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
+                   e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_);
            });
-           compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+           compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+           compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;

            // populate desc for Ds/E
            if constexpr(isMultiA || isMultiB)

@@ -638,6 +682,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        // tensor descriptors for problem definiton
        index_t num_group_;
+       index_t conv_N_per_block_;
+
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        DsGridDesc_M_N ds_grid_desc_m_n_;

@@ -655,7 +701,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        // for computing batch offset
-       ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor> compute_ptr_offset_of_batch_;
+       ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor> compute_ptr_offset_of_groups_;
+       ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor> compute_ptr_offset_of_n_;

        // element-wise op
        AElementwiseOperation a_element_op_;
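Note: the conv_N_per_block_ member introduced above comes from GetSplitedNSize, which this commit adds to transform_conv_fwd_to_gemm.hpp (see the last file in the diff): if the larger of the A/E byte footprints exceeds 2 GB, it finds the smallest divisor of N that brings a single slice back under the limit. A worked example with made-up sizes (fp16 NHWC input, not taken from the commit), using a simplified version of that search:

    #include <cstdint>
    #include <iostream>

    int main()
    {
        const std::int64_t TwoGB = std::int64_t{1} << 31;

        // Hypothetical fp16 NHWC input: N=64, Hi=Wi=384, C=192.
        const std::int64_t N = 64, Hi = 384, Wi = 384, C = 192;
        const std::int64_t bytes = N * Hi * Wi * C * 2; // ~3.6 GB, over the limit

        // Minimum divisor of N that brings one slice under 2 GB, then the least
        // actual divisor of N that is >= it (simplified from GetSplitedNSize).
        const std::int64_t min_divisor = (bytes + TwoGB - 1) / TwoGB; // ceil -> 2
        std::int64_t conv_N_per_block  = 1;
        for(std::int64_t d = min_divisor; d * d <= N; ++d)
        {
            if(N % d == 0)
            {
                conv_N_per_block = N / d;
                break;
            }
        }
        std::cout << "process " << conv_N_per_block << " images of N per slice\n"; // 32
    }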
@@ -689,8 +736,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                arg.Print();
            }

-           const index_t grid_size =
-               arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
+           const index_t num_workgroups_per_Conv_N =
+               arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
+           const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
+           const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N;
+           const index_t gdz = 1;

            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
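Note: with the N-split in place, the launcher below switches from a flat dim3(grid_size) to a 3D grid — x keeps the output-tile count, y covers the group and N-slice dimensions. With made-up numbers (not from the commit):

    #include <iostream>

    int main()
    {
        const int tiles_per_gemm = 120; // block_2_etile_map_.CalculateGridSize(...), assumed
        const int num_group      = 4;   // G
        const int N = 16, conv_N_per_block = 4; // N split into 16 / 4 = 4 slices

        const int num_workgroups_per_Conv_N = N / conv_N_per_block;
        const int gdx = tiles_per_gemm;
        const int gdy = num_group * num_workgroups_per_Conv_N;
        const int gdz = 1;

        // dim3(gdx, gdy, gdz) replaces dim3(grid_size): 120 x 16 x 1 workgroups here.
        std::cout << gdx << " x " << gdy << " x " << gdz << '\n';
    }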
@@ -721,6 +772,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                    DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    Block2ETileMap,
                    ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
+                   ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
                    has_main_loop,
                    isMultiA,
                    isMultiB>;

@@ -728,7 +780,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                return launch_and_time_kernel(
                    stream_config,
                    kernel,
-                   dim3(grid_size),
+                   dim3(gdx, gdy, gdz),
                    dim3(BlockSize),
                    0,
                    arg.p_as_grid_,

@@ -744,7 +796,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                    arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.block_2_etile_map_,
-                   arg.compute_ptr_offset_of_batch_);
+                   arg.compute_ptr_offset_of_groups_,
+                   arg.compute_ptr_offset_of_n_);
            }
            else
            {

@@ -763,6 +816,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                    DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    Block2ETileMap,
                    ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
+                   ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
                    has_main_loop,
                    isMultiA,
                    isMultiB>;

@@ -770,7 +824,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                return launch_and_time_kernel(
                    stream_config,
                    kernel,
-                   dim3(grid_size),
+                   dim3(gdx, gdy, gdz),
                    dim3(BlockSize),
                    0,
                    arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple

@@ -786,7 +840,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                    arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.block_2_etile_map_,
-                   arg.compute_ptr_offset_of_batch_);
+                   arg.compute_ptr_offset_of_groups_,
+                   arg.compute_ptr_offset_of_n_);
            }
        };
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp

@@ -60,7 +60,7 @@ template <typename GridwiseGemm,
          typename AGridDesc_AK0_M_K1,
          typename BGridDesc_BK0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-         typename ComputePtrOffsetOfBatch,
+         typename ComputePtrOffset,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,

@@ -69,26 +69,28 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
        kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
                                                const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
                                                const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
                                                const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                               const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-                                               const index_t batch_count)
+                                               const ComputePtrOffset compute_ptr_offset_of_groups,
+                                               const ComputePtrOffset compute_ptr_offset_of_n,
+                                               const index_t groups_count)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // offset base pointer for each work-group
-   const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count);
+   const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
+   const index_t& num_blocks_per_n    = groups_count;
    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
+   const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
-   const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-   const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-   const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+   const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
+   const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
+   const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
+   const long_index_t a_n_offset     = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
+   const long_index_t e_n_offset     = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);

    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -97,9 +99,9 @@ __global__ void
                                           CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                                           HasMainKBlockLoop,
                                           CGlobalMemoryDataOperation,
-                                          TailNum>(karg.p_a_grid + a_batch_offset,
+                                          TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
                                                    karg.p_b_grid + b_batch_offset,
-                                                   karg.p_c_grid + e_batch_offset,
+                                                   karg.p_c_grid + e_batch_offset + e_n_offset,
                                                    p_shared,
                                                    karg,
                                                    a_grid_desc_ak0_m_ak1,

@@ -114,7 +116,7 @@ template <typename GridwiseGemm,
          typename AGridDesc_AK0_M_K1,
          typename BGridDesc_BK0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-         typename ComputePtrOffsetOfBatch,
+         typename ComputePtrOffset,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,

@@ -129,20 +131,23 @@ __global__ void
        const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
        const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock,
-       const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-       const index_t batch_count)
+       const ComputePtrOffset compute_ptr_offset_of_groups,
+       const ComputePtrOffset compute_ptr_offset_of_n,
+       const index_t groups_count)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // offset base pointer for each work-group
-   const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count);
+   const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
+   const index_t& num_blocks_per_n    = groups_count;
    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
+   const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
-   const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-   const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-   const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+   const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx);
+   const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx);
+   const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx);
+   const long_index_t a_n_offset     = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
+   const long_index_t e_n_offset     = compute_ptr_offset_of_n.GetEPtrOffset(n_idx);

    // Pass two lds pointer is the key to tell compiler that ds_read/write
    // operate on different lds chunk at same time without order dependecy

@@ -154,9 +159,9 @@ __global__ void
                                           CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                                           HasMainKBlockLoop,
                                           CGlobalMemoryDataOperation,
-                                          TailNum>(karg.p_a_grid + a_batch_offset,
+                                          TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
                                                    karg.p_b_grid + b_batch_offset,
-                                                   karg.p_c_grid + e_batch_offset,
+                                                   karg.p_c_grid + e_batch_offset + e_n_offset,
                                                    p_shared_0,
                                                    p_shared_1,
                                                    karg,
@@ -294,7 +299,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
-       const std::array<index_t, NDimSpatial>& input_right_pads)
+       const std::array<index_t, NDimSpatial>& input_right_pads,
+       const index_t Conv_N)
    {
        const auto in_gemmmraw_gemmkraw_desc =
            conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,

@@ -306,7 +313,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                                                                        conv_filter_strides,
                                                                        conv_filter_dilations,
                                                                        input_left_pads,
-                                                                       input_right_pads);
+                                                                       input_right_pads,
+                                                                       Conv_N);
        const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

@@ -350,11 +358,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
    template <typename ELay>
    static auto MakeEGridDescriptor_M_N(
        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-       const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
+       const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
+       const index_t Conv_N)
    {
        const auto out_gemmmraw_gemmnraw_desc =
-           conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                       e_g_n_k_wos_strides);
+           conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+               e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
        const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);

@@ -363,7 +373,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
    }

    // desc for problem definition
-   using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
+   using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;

 #define GridwiseGemmV3TemplateParams \
    tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \

@@ -396,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
    // desc for blockwise copy
    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
-       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =

@@ -429,6 +439,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
              p_b_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e)},
              num_group_{a_g_n_c_wis_lengths[0]},
+             conv_N_per_block_{
+                 conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
+                     a_g_n_c_wis_lengths,
+                     a_g_n_c_wis_strides,
+                     e_g_n_k_wos_lengths,
+                     e_g_n_k_wos_strides)},
              a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
                                                                            a_g_n_c_wis_strides,
                                                                            b_g_k_c_xs_lengths,

@@ -438,13 +454,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                                                                            conv_filter_strides,
                                                                            conv_filter_dilations,
                                                                            input_left_pads,
-                                                                           input_right_pads)},
+                                                                           input_right_pads,
+                                                                           conv_N_per_block_)},
              b_grid_desc_bk0_n_bk1_{
                  MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
-             e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
-                                                                         e_g_n_k_wos_strides)},
+             e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
+                 e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
-             compute_ptr_offset_of_batch_{},
+             compute_ptr_offset_of_groups_{},
+             compute_ptr_offset_of_n_{},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op},

@@ -459,15 +477,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
              input_left_pads_{input_left_pads},
              input_right_pads_{input_right_pads}
        {
-           // A/B/E Batch Stride
-           compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-           compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+           // A/B/E Batch/N Stride
+           compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
+           compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+           compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;

            // p_as and p_bs are pointers
            p_a_grid_ = static_cast<const ADataType*>(p_as);
            p_b_grid_ = static_cast<const BDataType*>(p_bs);

-           compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+           compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+           compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;

            e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);

@@ -488,6 +508,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
        // tensor descriptors for problem definiton
        index_t num_group_;
+       index_t conv_N_per_block_;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;

@@ -496,7 +517,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;

        // for computing batch offset
-       ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
+       ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_groups_;
+       ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;

        // element-wise op
        AElementwiseOperation a_element_op_;

@@ -538,11 +560,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
            const index_t GemmK =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+           const index_t num_workgroups_per_Conv_N =
+               arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
-           gdy *= arg.num_group_;
+           gdy *= arg.num_group_ * num_workgroups_per_Conv_N;

            index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

@@ -579,7 +604,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                        arg.a_grid_desc_ak0_m_ak1_,
                        arg.b_grid_desc_bk0_n_bk1_,
                        arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
-                       arg.compute_ptr_offset_of_batch_,
+                       arg.compute_ptr_offset_of_groups_,
+                       arg.compute_ptr_offset_of_n_,
                        arg.num_group_);
                }
                else

@@ -594,7 +620,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                        arg.a_grid_desc_ak0_m_ak1_,
                        arg.b_grid_desc_bk0_n_bk1_,
                        arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
-                       arg.compute_ptr_offset_of_batch_,
+                       arg.compute_ptr_offset_of_groups_,
+                       arg.compute_ptr_offset_of_n_,
                        arg.num_group_);
                }
        };
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -338,7 +338,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
                                                                         conv_filter_strides,
                                                                         conv_filter_dilations,
                                                                         input_left_pads,
-                                                                        input_right_pads);
+                                                                        input_right_pads,
+                                                                        a_g_n_c_wis_lengths[I1]);
         const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

@@ -367,8 +368,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
         const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
     {
         const auto out_gemmmraw_gemmnraw_desc =
-            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                        e_g_n_k_wos_strides);
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+                e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
         const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -163,7 +163,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
                                                                         conv_filter_strides,
                                                                         conv_filter_dilations,
                                                                         input_left_pads,
-                                                                        input_right_pads);
+                                                                        input_right_pads,
+                                                                        a_g_n_c_wis_lengths[I1]);
         const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

@@ -255,8 +256,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
         const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
     {
         const auto out_gemmmraw_gemmnraw_desc =
-            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                        e_g_n_k_wos_strides);
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
+                e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
         const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp

@@ -68,14 +68,14 @@ template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
 struct ComputePtrOffsetOfStridedBatch<NumATensor,
                                       NumBTensor,
                                       NumDTensor,
-                                      ck::enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
+                                      enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
 {
     ComputePtrOffsetOfStridedBatch() = default;

-    ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor>& BatchStrideAs,
-                                   Array<ck::index_t, NumBTensor>& BatchStrideBs,
-                                   Array<ck::index_t, NumDTensor>& BatchStrideDs,
-                                   index_t BatchStrideE)
+    ComputePtrOffsetOfStridedBatch(Array<long_index_t, NumATensor>& BatchStrideAs,
+                                   Array<long_index_t, NumBTensor>& BatchStrideBs,
+                                   Array<long_index_t, NumDTensor>& BatchStrideDs,
+                                   long_index_t BatchStrideE)
         : BatchStrideA_(BatchStrideAs),
           BatchStrideB_(BatchStrideBs),
           BatchStrideDs_(BatchStrideDs),

@@ -87,7 +87,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
     {
         Array<long_index_t, NumATensor> as_offset;
         static_for<0, NumATensor, 1>{}(
-            [&](auto i) { as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]); });
+            [&](auto i) { as_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideA_[i]; });
         return as_offset;
     }

@@ -95,7 +95,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
     {
         Array<long_index_t, NumBTensor> bs_offset;
         static_for<0, NumBTensor, 1>{}(
-            [&](auto i) { bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]); });
+            [&](auto i) { bs_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideB_[i]; });
         return bs_offset;
     }

@@ -103,40 +103,40 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
     {
         Array<long_index_t, NumDTensor> ds_offset;
         static_for<0, NumDTensor, 1>{}(
-            [&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
+            [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
         return ds_offset;
     }

     [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
     {
-        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+        return static_cast<long_index_t>(g_idx) * BatchStrideE_;
     }

     // alias for kernels without multiple D
     [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
     {
-        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+        return static_cast<long_index_t>(g_idx) * BatchStrideE_;
     }

-    Array<ck::index_t, NumATensor> BatchStrideA_;
-    Array<ck::index_t, NumBTensor> BatchStrideB_;
-    Array<ck::index_t, NumDTensor> BatchStrideDs_;
-    index_t BatchStrideE_;
-    index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
+    Array<long_index_t, NumATensor> BatchStrideA_;
+    Array<long_index_t, NumBTensor> BatchStrideB_;
+    Array<long_index_t, NumDTensor> BatchStrideDs_;
+    long_index_t BatchStrideE_;
+    long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
 };

 template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
 struct ComputePtrOffsetOfStridedBatch<NumATensor,
                                       NumBTensor,
                                       NumDTensor,
-                                      ck::enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
+                                      enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
 {
     ComputePtrOffsetOfStridedBatch() = default;

-    ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
-                                   index_t BatchStrideB,
-                                   Array<ck::index_t, NumDTensor> BatchStrideDs,
-                                   index_t BatchStrideE)
+    ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA,
+                                   long_index_t BatchStrideB,
+                                   Array<long_index_t, NumDTensor> BatchStrideDs,
+                                   long_index_t BatchStrideE)
         : BatchStrideA_(BatchStrideA),
           BatchStrideB_(BatchStrideB),
           BatchStrideDs_(BatchStrideDs),

@@ -146,38 +146,38 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
     __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
     {
-        return g_idx * static_cast<long_index_t>(BatchStrideA_);
+        return static_cast<long_index_t>(g_idx) * BatchStrideA_;
     }

     __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
     {
-        return g_idx * static_cast<long_index_t>(BatchStrideB_);
+        return static_cast<long_index_t>(g_idx) * BatchStrideB_;
     }

     __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
     {
         Array<long_index_t, NumDTensor> ds_offset;
         static_for<0, NumDTensor, 1>{}(
-            [&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
+            [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
         return ds_offset;
     }

     [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
     {
-        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+        return static_cast<long_index_t>(g_idx) * BatchStrideE_;
     }

     // alias for kernels without multiple D
     [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
     {
-        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+        return static_cast<long_index_t>(g_idx) * BatchStrideE_;
     }

-    ck::index_t BatchStrideA_;
-    ck::index_t BatchStrideB_;
-    Array<ck::index_t, NumDTensor> BatchStrideDs_;
-    index_t BatchStrideE_;
-    index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
+    long_index_t BatchStrideA_;
+    long_index_t BatchStrideB_;
+    Array<long_index_t, NumDTensor> BatchStrideDs_;
+    long_index_t BatchStrideE_;
+    long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
 };

 template <bool isTuple, typename Tensors>
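Note: both specializations now carry their batch strides as long_index_t instead of ck::index_t, and the offset expressions promote g_idx rather than the stride. The multiplication was already done in 64-bit before; what changes is that a stride larger than INT32_MAX can now be stored and passed through the constructor at all. A stand-alone sketch with stand-in typedefs (not CK code):

    #include <cstdint>
    #include <iostream>

    using index_t      = std::int32_t;
    using long_index_t = std::int64_t;

    int main()
    {
        // A per-group stride of ~3e9 elements does not fit in a 32-bit index_t...
        const long_index_t batch_stride = 3'000'000'000LL;
        // ...so a struct member declared as index_t could not even hold it.
        // With long_index_t members, the per-group offset math is straightforward:
        const index_t g_idx       = 2;
        const long_index_t offset = static_cast<long_index_t>(g_idx) * batch_stride;

        std::cout << offset << '\n'; // 6000000000, representable only in 64 bits
    }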
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -108,7 +108,8 @@ struct DeviceImageToColumnImpl
                                             conv_filter_strides,
                                             conv_filter_dilations,
                                             input_left_pads,
-                                            input_right_pads);
+                                            input_right_pads,
+                                            N);
        const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -60,12 +60,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

@@ -155,12 +152,9 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
+    const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
+    const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
View file @ ce30621d
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
...
@@ -20,6 +20,71 @@ struct TransformConvFwdToGemm
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

+    static long_index_t
+    calculate_element_space_size_impl(const std::array<index_t, NDimSpatial + 3>& lengths,
+                                      const std::array<index_t, NDimSpatial + 3>& strides,
+                                      index_t i)
+    {
+        long_index_t acc = 1;
+        for(; i < (NDimSpatial + 3); i++)
+        {
+            acc += static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
+        }
+        return acc;
+    }
+
+    template <typename ADataType, typename CDataType>
+    static index_t GetSplitedNSize(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
+                                   const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+                                   const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
+                                   const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+    {
+        const long_index_t a_element_space_size =
+            calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
+        const long_index_t c_element_space_size =
+            calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
+        const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
+                                                          c_element_space_size * sizeof(CDataType));
+        constexpr long_index_t TwoGB = (long_index_t{1} << 31);
+
+        const index_t N = a_g_n_c_wis_lengths[I1];
+
+        if(element_space_size > TwoGB)
+        {
+            // Minimum divisor of N to not exceed 2GB
+            const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
+
+            if(divisor <= static_cast<double>(N))
+            {
+                // Find least divisor of N larger than element_space_size / TwoGB
+                // Iterate up to sqrt(N). There are no divisors above this value.
+                for(index_t least_divisor = divisor; least_divisor * least_divisor <= N;
+                    least_divisor++)
+                {
+                    if(N % least_divisor == 0)
+                    {
+                        return N / least_divisor;
+                    }
+                }
+                // Not found, process one Convolution N per block
+                return 1;
+            }
+            else
+            {
+                // Not possible to support even after split N.
+                // Too large tensor.
+                return N;
+            }
+        }
+        else
+        {
+            // Split N is not needed.
+            return N;
+        }
+    }
+
     // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
     // properties
     template <typename ALayout,
...
...
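GetSplitedNSize caps the addressed span of the A/C tensors at 2 GiB by dividing N by its smallest divisor that is at least element_space_size / 2 GiB; the value it returns is what the descriptor builders below now receive through their new N parameter. A standalone sketch of the same search on plain integers, with illustrative sizes and without CK's math:: helpers:

    #include <cstdint>
    #include <iostream>

    // Largest per-split batch size N / d (d a divisor of N) that keeps the
    // tensor under the 2 GiB addressing limit; mirrors GetSplitedNSize above.
    std::int64_t split_n(std::int64_t N, std::int64_t tensor_bytes)
    {
        const std::int64_t TwoGB = std::int64_t{1} << 31;
        if(tensor_bytes <= TwoGB)
            return N; // no split needed

        const std::int64_t divisor = (tensor_bytes + TwoGB - 1) / TwoGB; // ceil
        if(divisor > N)
            return N; // splitting N alone cannot help

        for(std::int64_t d = divisor; d * d <= N; ++d)
            if(N % d == 0)
                return N / d; // least divisor of N that is >= divisor

        return 1; // none found up to sqrt(N): process one N per split
    }

    int main()
    {
        // 6 GiB tensor, N = 128: required ratio is ceil(6/2) = 3, and the least
        // divisor of 128 that is >= 3 is 4, so each split handles N = 32.
        std::cout << split_n(128, std::int64_t{6} << 30) << "\n"; // prints 32
    }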
@@ -38,9 +103,9 @@ struct TransformConvFwdToGemm
         const std::array<index_t, NDimSpatial>& conv_filter_strides,
         const std::array<index_t, NDimSpatial>& conv_filter_dilations,
         const std::array<index_t, NDimSpatial>& input_left_pads,
-        const std::array<index_t, NDimSpatial>& input_right_pads)
+        const std::array<index_t, NDimSpatial>& input_right_pads,
+        const index_t N)
     {
-        const index_t N = a_g_n_c_wis_lengths[1];
         const index_t C = a_g_n_c_wis_lengths[2];
         const index_t Wi = a_g_n_c_wis_lengths[3];
...
...
@@ -151,9 +216,10 @@ struct TransformConvFwdToGemm
         const std::array<index_t, NDimSpatial>& conv_filter_strides,
         const std::array<index_t, NDimSpatial>& conv_filter_dilations,
         const std::array<index_t, NDimSpatial>& input_left_pads,
-        const std::array<index_t, NDimSpatial>& input_right_pads)
+        const std::array<index_t, NDimSpatial>& input_right_pads,
+        const index_t N)
     {
-        const index_t N = a_g_n_c_wis_lengths[1];
         const index_t C = a_g_n_c_wis_lengths[2];
         const index_t Hi = a_g_n_c_wis_lengths[3];
...
...
@@ -276,13 +342,14 @@ struct TransformConvFwdToGemm
         const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
         const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
         const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides
-        */,
+        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides*/,
         const std::array<index_t, NDimSpatial>& conv_filter_strides,
         const std::array<index_t, NDimSpatial>& conv_filter_dilations,
         const std::array<index_t, NDimSpatial>& input_left_pads,
-        const std::array<index_t, NDimSpatial>& input_right_pads)
+        const std::array<index_t, NDimSpatial>& input_right_pads,
+        const index_t N)
     {
-        const index_t N = a_g_n_c_wis_lengths[1];
         const index_t C = a_g_n_c_wis_lengths[2];
         const index_t Di = a_g_n_c_wis_lengths[3];
...
...
@@ -478,9 +545,9 @@ struct TransformConvFwdToGemm
                                       bool>::type = false>
     static auto MakeCDescriptor_M_N(
         const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
+        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
+        const index_t N)
     {
-        const index_t N = c_g_n_k_wos_lengths[1];
         const index_t K = c_g_n_k_wos_lengths[2];
         const index_t NHoWo =
...
...
@@ -502,9 +569,9 @@ struct TransformConvFwdToGemm
                   is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
               bool>::type = false>
     static auto MakeCDescriptor_M_N(
         const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
+        const index_t N)
     {
-        const index_t N = c_g_n_k_wos_lengths[1];
         const index_t K = c_g_n_k_wos_lengths[2];
         const auto KStride = I1;
...
...
@@ -525,9 +592,9 @@ struct TransformConvFwdToGemm
               typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
                                       bool>::type = false>
     static auto MakeCDescriptor_M_N(
         const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
+        const index_t N)
     {
-        const index_t N = c_g_n_k_wos_lengths[1];
         const index_t K = c_g_n_k_wos_lengths[2];
         const index_t KStride = c_g_n_k_wos_strides[2];
...
...
include/ck_tile/core/arch/arch.hpp
View file @ ce30621d
...
...
@@ -61,10 +61,13 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
 CK_TILE_DEVICE void block_sync_lds()
 {
 #if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
-    asm volatile("\
-    s_waitcnt lgkmcnt(0) \n \
-    s_barrier \
-    " ::);
+    // asm volatile("\
+    // s_waitcnt lgkmcnt(0) \n \
+    // s_barrier \
+    // " ::);
+    __builtin_amdgcn_s_waitcnt(0xc07f);
+    __builtin_amdgcn_s_barrier();
 #else
     __syncthreads();
 #endif
...
...
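Replacing the opaque inline asm with __builtin_amdgcn_s_waitcnt / __builtin_amdgcn_s_barrier lets the compiler see and schedule around the barrier. On gfx9-style encodings the immediate 0xc07f keeps vmcnt and expcnt at their maxima and waits only on lgkmcnt(0), matching the old s_waitcnt lgkmcnt(0). A hedged sketch of typical usage, assuming block_sync_lds() lives in the ck_tile namespace and that ck_tile/core.hpp is the umbrella include; the kernel itself is illustrative and not part of this commit:

    #include "ck_tile/core.hpp"

    // Exchange values through LDS; block_sync_lds() orders the LDS write
    // against the cross-lane read. Launch with 256 threads per block.
    __global__ void lds_exchange(const float* __restrict__ in, float* __restrict__ out)
    {
        __shared__ float smem[256];
        const int tid = threadIdx.x;

        smem[tid] = in[blockIdx.x * 256 + tid];        // LDS write
        ck_tile::block_sync_lds();                     // waitcnt lgkmcnt(0) + s_barrier (or __syncthreads fallback)
        out[blockIdx.x * 256 + tid] = smem[255 - tid]; // read another lane's slot
    }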
include/ck_tile/core/config.hpp
View file @ ce30621d
...
...
@@ -167,6 +167,10 @@
 #define CK_TILE_USE_SUBDWORD_TILE_CAST 0
 #endif

+#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
+#define CK_TILE_USE_PK_FP16_TILE_CAST 0
+#endif
+
 // TODO: better solve this inside compiler
 #ifndef CK_TILE_FMHA_FWD_FAST_EXP2
 #define CK_TILE_FMHA_FWD_FAST_EXP2 0
...
...
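CK_TILE_USE_PK_FP16_TILE_CAST defaults to 0, so the packed fp32-to-fp16 cast path added in tile_elementwise.hpp stays opt-in. Presumably it is enabled per translation unit or from the build line, e.g.:

    // Opt in before any ck_tile include, or pass -DCK_TILE_USE_PK_FP16_TILE_CAST=1
    // on the compiler command line; config.hpp otherwise keeps the path disabled.
    #define CK_TILE_USE_PK_FP16_TILE_CAST 1
    #include "ck_tile/core.hpp"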
include/ck_tile/core/tensor/tile_elementwise.hpp
View file @ ce30621d
...
...
@@ -110,7 +110,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
 namespace impl {
 // TODO: this is ugly
 template <typename OutDataType, typename InTensor>
-CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
+CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
 {
 #if defined(__gfx94__)
     // This API is designed to use the _pk_ serious of function
...
...
@@ -156,6 +156,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
 #endif
 }

+template <typename OutDataType, typename InTensor>
+CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
+{
+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
+    // This API is designed to use the _pk_ serious of function
+    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
+
+    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
+    static_assert(thread_buffer_size % 2 == 0);
+    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
+
+    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
+
+    // TODO: this is rtz cvt, need be very careful
+    for(index_t i = 0; i < thread_buffer_size_pk; i++)
+    {
+        auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
+                                            in_dstr_tensors.get_thread_buffer()[2 * i + 1]);
+
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
+    }
+
+    return out_dstr_tensor;
+#else
+    // fallback
+    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
+                               in_dstr_tensors);
+#endif
+}
+
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
 // this function assume either src or dst (or both) date type is under 1 dword
 // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
...
...
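cast_tile_pk_fp16_fp32 relies on __builtin_amdgcn_cvt_pkrtz, which converts two fp32 values into a packed pair of halves using round-toward-zero (hence the "rtz cvt" warning in the comment). A minimal device-side sketch of the same packed conversion outside the tile machinery; the kernel, its launch shape, and the _Float16 output type are illustrative assumptions for gfx9-class targets:

    // Convert a float buffer to fp16 two elements at a time, mirroring the
    // packed conversion in cast_tile_pk_fp16_fp32. Round-toward-zero, so the
    // result may differ by one ulp from a round-to-nearest cast.
    __global__ void pk_cast_fp32_to_fp16(const float* __restrict__ src,
                                         _Float16* __restrict__ dst,
                                         int num_pairs)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < num_pairs)
        {
            // Returns a 2-wide half vector; .x/.y are the converted lanes.
            auto h2 = __builtin_amdgcn_cvt_pkrtz(src[2 * i + 0], src[2 * i + 1]);
            dst[2 * i + 0] = h2.x;
            dst[2 * i + 1] = h2.y;
        }
    }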
@@ -229,8 +260,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
                                     float> &&
                      (SrcTensor::get_thread_buffer_size() % 4 == 0))
     {
-        return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
+        return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
     }
+#if CK_TILE_USE_PK_FP16_TILE_CAST
+    else if constexpr(std::is_same_v<DstType, fp16_t> &&
+                      std::is_same_v<typename SrcTensor::DataType, float> &&
+                      (SrcTensor::get_thread_buffer_size() % 2 == 0))
+    {
+        return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
+    }
+#endif
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
     else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
     {
...
...
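With the new branch, cast_tile<fp16_t>(...) on a float tile whose per-thread buffer has even length takes the packed cvt_pkrtz path when CK_TILE_USE_PK_FP16_TILE_CAST is enabled, and otherwise falls back to the element-wise type_convert path. A toy, host-compilable mirror of that compile-time dispatch; the types and buffer sizes are stand-ins, not ck_tile's:

    #include <cstdio>
    #include <type_traits>

    struct toy_fp16 {}; // stand-in for ck_tile::fp16_t in this sketch

    // Mirrors cast_tile's if-constexpr selection: packed 2-wide conversion only
    // when the destination is fp16-like, the source is float and the per-thread
    // buffer length is even.
    template <typename Dst, typename Src, int BufSize>
    constexpr const char* chosen_path()
    {
        if constexpr(std::is_same_v<Dst, toy_fp16> && std::is_same_v<Src, float> &&
                     (BufSize % 2 == 0))
            return "packed fp16 path (cvt_pkrtz)";
        else
            return "element-wise type_convert path";
    }

    int main()
    {
        std::puts(chosen_path<toy_fp16, float, 8>()); // packed
        std::puts(chosen_path<float, float, 8>());    // element-wise
    }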