gaoqiong / composable_kernel_ROCM / Commits / a8efb3f0

Unverified commit a8efb3f0, authored Jul 16, 2024 by Illia Silin, committed by GitHub on Jul 16, 2024

    Merge branch 'develop' into lwpck-1815

Parents: 06701e70, 9cac2827
Changes: 51 files in the merge; the first 20 changed files are shown below, with 1495 additions and 366 deletions (+1495, -366).
Changed files:

  +2    -1    example/ck_tile/01_fmha/script/smoke_test_bwd.sh
  +3    -1    example/ck_tile/01_fmha/script/smoke_test_fwd.sh
  +3    -1    include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
  +38   -38   include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
  +84   -43   include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
  +16   -0    include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
  +25   -0    include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
  +45   -45   include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp
  +785  -232  include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
  +4    -4    library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
  +96   -0    library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
  +37   -0    library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp
  +13   -0    library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
  +24   -0    library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp
  +105  -0    library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp
  +112  -0    library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc
  +2    -1    library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp
  +5    -0    library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
  +48   -0    library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
  +48   -0    library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp
example/ck_tile/01_fmha/script/smoke_test_bwd.sh  (mode 100644 → 100755)

Makes the script executable and wraps the test loops in shell tracing (set -x before the loops, set +x after):

@@ -8,7 +8,7 @@ export CK_WARMUP=0
 export CK_REPEAT=1

 COMMON_ARGS='-v=1'
-
+set -x
 for prec in "fp16" "bf16" ; do
 for perm in 0 1 ; do
 for hdim in 32 64 128 ; do
@@ -31,3 +31,4 @@ done
 done
 done
 done
+set +x
example/ck_tile/01_fmha/script/smoke_test_fwd.sh

The same tracing change around both loop nests:

@@ -10,7 +10,7 @@ export CK_REPEAT=1
 COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
 # mode=0
 # export HIP_VISIBLE_DEVICES=4
-
+set -x
 for prec in "fp16" "bf16" ; do
 for mode in 1 0 ; do
 for perm in 0 1 ; do
@@ -40,6 +40,7 @@ done
 done
 done
+
 for perm in 0 1 ; do
 for bias in "n" "e" "a" ; do
 for b in 1 2 ; do
@@ -49,3 +50,4 @@ done
 done
 done
 done
+set +x
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp

Adds a Filter3x3 forward-convolution specialization (and bumps the copyright year):

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -15,6 +15,7 @@ enum struct ConvolutionForwardSpecialization
     Filter1x1Pad0,
     Filter1x1Stride1Pad0,
     OddC,
+    Filter3x3,
 };

 inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s)
@@ -25,6 +26,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp
     case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
     case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
     case ConvolutionForwardSpecialization::OddC: return "OddC";
+    case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3";
     default: return "Unrecognized specialization!";
     }
 }
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp

A pure rename: the template parameter NumBatchToMerge becomes NumGroupsToMerge everywhere it appears in this file (38 changed lines). Representative hunks:

@@ -36,7 +36,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_BK0_N_K1,
           typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
           typename ComputePtrOffsetOfBatch,
-          index_t NumBatchToMerge,
+          index_t NumGroupsToMerge,
           bool HasMainKBlockLoop,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           index_t MinimumOccupancy = 1,
@@ -56,7 +56,7 @@ __global__ void
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
-    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
+    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
     const long_index_t a_batch_offset =
@@ -189,7 +189,7 @@ template <ck::index_t NDimSpatial,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
           BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
           BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
-          index_t NumBatchToMerge = 1,
+          index_t NumGroupsToMerge = 1,
           typename ComputeTypeA = InDataType,
           typename ComputeTypeB = ComputeTypeA>
 struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
@@ -238,7 +238,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                                        NPerBlock,
                                        K1Number,
                                        KPerBlock / K1Number,
-                                       NumBatchToMerge,
+                                       NumGroupsToMerge,
                                        ConvBackwardWeightSpecialization>{};

     static constexpr auto conv_to_gemm_transformer_v1 =
@@ -638,7 +638,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                 index_t gdx, gdy, gdz;
                 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
-                    gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumBatchToMerge);
+                    gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge);

                 float ave_time = 0;

The same one-token substitution repeats in the second kernel's template list and body (@@ -92 and @@ -113, the latter guarded by __gfx940__/__gfx941__/__gfx942__) and in the kernel-alias hunks at @@ -724, -739, -760, -777, -796, -817, -838, -859, -879, -900, -920, -937, -956, -977, -998, -1019, -1039, -1060, -1084, -1100, -1119, -1135, -1157, -1173, -1192, -1208, -1232 and -1247, which differ only in their trailing arguments (InMemoryDataOperationEnum::Set vs. AtomicAdd, true vs. false, minimum_occupancy), e.g.:

                 remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-                NumBatchToMerge,
+                NumGroupsToMerge,
                 true,
                 InMemoryDataOperationEnum::AtomicAdd,
                 minimum_occupancy>;

@@ -1389,7 +1389,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
             }
         }
-        if constexpr(NumBatchToMerge > 1)
+        if constexpr(NumGroupsToMerge > 1)
         {
             // support only if whole M and N can be proccessed on one block
             if(!(GemmM <= MPerBlock && GemmN <= NPerBlock))
@@ -1400,7 +1400,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
             {
                 return false;
             }
-            if(arg.Conv_G_ % NumBatchToMerge != 0)
+            if(arg.Conv_G_ % NumGroupsToMerge != 0)
             {
                 return false;
             }
@@ -1563,7 +1563,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
             << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
             << "BlkGemmPipelineVersion: "
             << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
-            << NumBatchToMerge
+            << NumGroupsToMerge
             << ">";
         // clang-format on
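The rename only changes naming, but the index math it feeds is worth restating. A minimal host-side sketch of the merged-groups indexing (sizes are assumed, not taken from the commit): the grid's z dimension enumerates merged slots, and each slot starts at group blockIdx.z * NumGroupsToMerge, exactly as in the g_idx line above.

#include <cstdio>

// Sketch only: mock of the kernel's merged-groups index arithmetic.
int main()
{
    constexpr int G                = 8; // total convolution groups (assumed)
    constexpr int NumGroupsToMerge = 4; // groups handled per work-group slot
    const int grid_z = G / NumGroupsToMerge; // z-dimension of the launch grid

    for(int block_z = 0; block_z < grid_z; ++block_z)
    {
        // Mirrors: g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge)
        const int g_idx = block_z * NumGroupsToMerge;
        std::printf("block z=%d starts at group %d, covers groups [%d, %d)\n",
                    block_z, g_idx, g_idx, g_idx + NumGroupsToMerge);
    }
}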
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

The kernel no longer takes a groups_count argument: the group index now comes straight from blockIdx.y and the N-split index from blockIdx.z, and the *_batch_offset locals are renamed *_group_offset to match.

@@ -86,7 +86,6 @@ __global__ void
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
             const CDEElementwiseOperation cde_element_op,
-            const index_t groups_count,
             const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
             const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
             const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -101,14 +100,11 @@ __global__ void
     defined(__gfx94__))
     // offset base pointer for each work-group
-    const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
-    const index_t& num_blocks_per_n    = groups_count;
-    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
-    const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
-    const long_index_t e_batch_offset =
+    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
+    const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
+    const long_index_t e_group_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
-    const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
+    const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
     const long_index_t e_n_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
@@ -121,14 +117,14 @@ __global__ void
         DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
     static_for<0, NumDTensor, 1>{}(
-        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
+        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });

     if constexpr(isMultiA || isMultiB)
     {
         AsPointer p_as_grid_grp;
         BsPointer p_bs_grid_grp;
-        const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
+        const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
         // compute_ptr_offset_of_n_ not need BatchStrideB so
         // in case of MultiA is false but isMultiB is true
@@ -139,27 +135,27 @@ __global__ void
         static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
         static_for<0, NumATensor, 1>{}([&](auto i) {
-            p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i];
+            p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i];
         });
     }
     else
     {
         const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
         static_for<0, 1, 1>{}(
-            [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; });
+            [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; });
     }
-    const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
+    const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
     static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
     static_for<0, NumBTensor, 1>{}(
-        [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
+        [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; });
     GridwiseGemm::template Run<HasMainKBlockLoop>(
         p_as_grid_grp,
         p_bs_grid_grp,
         p_ds_grid_grp,
-        p_e_grid + e_batch_offset + e_n_offset,
+        p_e_grid + e_group_offset + e_n_offset,
         p_shared,
         a_element_op,
         b_element_op,
@@ -172,19 +168,19 @@ __global__ void
     }
     else
     {
-        const long_index_t a_batch_offset =
+        const long_index_t a_group_offset =
             amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
-        const long_index_t b_batch_offset =
+        const long_index_t b_group_offset =
             amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
         const long_index_t a_n_offset =
             amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
         GridwiseGemm::template Run<HasMainKBlockLoop>(
-            p_as_grid + a_batch_offset + a_n_offset,
-            p_bs_grid + b_batch_offset,
+            p_as_grid + a_group_offset + a_n_offset,
+            p_bs_grid + b_group_offset,
             p_ds_grid_grp,
-            p_e_grid + e_batch_offset + e_n_offset,
+            p_e_grid + e_group_offset + e_n_offset,
             p_shared,
             a_element_op,
             b_element_op,
@@ -200,7 +196,6 @@ __global__ void
     ignore = p_bs_grid;
     ignore = p_ds_grid;
     ignore = p_e_grid;
-    ignore = groups_count;
     ignore = a_grid_desc_k0_m_k1;
     ignore = b_grid_desc_k0_n_k1;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
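A minimal sketch of the new decode, assuming a hypothetical 3-D launch where gridDim.y enumerates (merged) groups and gridDim.z enumerates splits of the convolution N dimension; both indices previously had to be divided out of blockIdx.y via groups_count.

#include <cstdio>

// Sketch only: the two loop variables stand in for blockIdx.y / blockIdx.z.
int main()
{
    const int grid_y = 4; // number of group slots (assumed)
    const int grid_z = 2; // number of N splits (assumed)
    for(int by = 0; by < grid_y; ++by)
        for(int bz = 0; bz < grid_z; ++bz)
        {
            const int g_idx = by; // was: blockIdx.y / num_blocks_per_batch
            const int n_idx = bz; // was: also derived from blockIdx.y
            std::printf("block (y=%d, z=%d) -> group %d, N split %d\n", by, bz, g_idx, n_idx);
        }
}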
The device op gains a NumGroupsToMerge template parameter; per-group pointer strides scale by it, and the launch grid's group dimension shrinks by the same factor while the N split moves to the z dimension.

@@ -287,7 +282,8 @@ template <index_t NDimSpatial,
           // in tuple for MultiAB), unpack if tuple was
           // passed
           typename BComputeDataType = AComputeDataType,
-          LoopScheduler LoopSched = make_default_loop_scheduler()>
+          LoopScheduler LoopSched = make_default_loop_scheduler(),
+          index_t NumGroupsToMerge = 1>
 struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
                                              ALayout,
@@ -306,6 +302,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
 {
     using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;

+    static_assert(NumGroupsToMerge >= 1);
+
     static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
     static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
@@ -319,7 +317,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     static constexpr auto I3 = Number<3>{};

     static constexpr auto conv_to_gemm_transformer =
-        TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
+        TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization, NumGroupsToMerge>{};

     static constexpr auto matrix_padder =
         MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
@@ -550,7 +548,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
             static_for<0, NumATensor, 1>{}([&](auto i) {
                 // Init compute_ptr_offset_of_groups_ for multiple AB
-                compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
+                compute_ptr_offset_of_groups_.BatchStrideA_(i) =
+                    a_g_n_c_wis_strides[0] * NumGroupsToMerge;
                 // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
                 // type is not tuple)
@@ -578,7 +577,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
             static_for<0, NumBTensor, 1>{}([&](auto i) {
                 // Init compute_ptr_offset_of_groups_ for multiple AB
-                compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
+                compute_ptr_offset_of_groups_.BatchStrideB_(i) =
+                    b_g_k_c_xs_strides[0] * NumGroupsToMerge;
                 using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
                 // It is possible that one of the AB is a pointer and one is a tuple.
@@ -598,8 +598,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
         }
         else
         {
-            compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-            compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+            compute_ptr_offset_of_groups_.BatchStrideA_ =
+                a_g_n_c_wis_strides[0] * NumGroupsToMerge;
+            compute_ptr_offset_of_groups_.BatchStrideB_ =
+                b_g_k_c_xs_strides[0] * NumGroupsToMerge;
             compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
             // p_as and p_bs are pointers
@@ -616,7 +618,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
             p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
             // D batch stride
-            compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
+            compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
+                ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge;
             compute_ptr_offset_of_n_.BatchStrideDs_(i) =
                 ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
@@ -624,7 +627,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
             ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
                 e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_);
         });
-        compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+        compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge;
         compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
         // populate desc for Ds/E
@@ -745,8 +748,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
             arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
         const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
-        const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N;
-        const index_t gdz = 1;
+        const index_t gdy = arg.num_group_ / NumGroupsToMerge;
+        const index_t gdz = num_workgroups_per_Conv_N;

         const auto K =
             arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
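A worked example of the new offset and grid arithmetic, with assumed sizes: fusing NumGroupsToMerge consecutive groups multiplies the per-slot pointer stride by the merge factor and divides the grid's group dimension by the same factor.

#include <cstdio>

// Sketch only: host-side mock of the merged-groups stride/grid math.
int main()
{
    constexpr long group_stride     = 1024; // a_g_n_c_wis_strides[0] (assumed)
    constexpr int  G                = 8;    // arg.num_group_ (assumed)
    constexpr int  NumGroupsToMerge = 2;

    // Host side: grid y shrinks by the merge factor; the N split moves to z.
    const int gdy = G / NumGroupsToMerge; // 4 merged-group slots

    // Device side: each slot steps over NumGroupsToMerge groups at once.
    for(int g_idx = 0; g_idx < gdy; ++g_idx)
    {
        const long a_group_offset = g_idx * (group_stride * NumGroupsToMerge);
        std::printf("slot %d -> element offset %ld\n", g_idx, a_group_offset);
    }
}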
The kernel launches drop the explicit group count, and IsSupportedArgument learns the Filter3x3 and merged-groups rules (G, K, C are hoisted to the top of the function so the later vector-access checks can reuse them):

@@ -795,7 +798,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.cde_element_op_,
-                    arg.a_g_n_c_wis_lengths_[0], // Group count
                     as_grid_desc_ak0_m_ak1,
                     bs_grid_desc_bk0_n_bk1,
                     arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -839,7 +841,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.cde_element_op_,
-                    arg.a_g_n_c_wis_lengths_[0], // Group count
                     arg.a_grid_desc_ak0_m_ak1_,
                     arg.b_grid_desc_bk0_n_bk1_,
                     arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -871,6 +872,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     {
         namespace ctc = tensor_layout::convolution;

+        const index_t G = arg.b_g_k_c_xs_lengths_[I0];
+        const index_t K = arg.b_g_k_c_xs_lengths_[I1];
+        const index_t C = arg.b_g_k_c_xs_lengths_[I2];
+
         // check device
         if(get_device_name() == "gfx908")
         {
@@ -919,6 +924,42 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
             }
         }
+        else if constexpr(ConvForwardSpecialization ==
+                          ConvolutionForwardSpecialization::Filter3x3)
+        {
+            if(C != 1)
+            {
+                return false;
+            }
+            for(index_t i = 0; i < NDimSpatial; ++i)
+            {
+                const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
+                if(filter_spatial_dim != I3)
+                {
+                    return false;
+                }
+            }
+            if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
+            {
+                return false;
+            }
+        }
+        if constexpr(NumGroupsToMerge > 1)
+        {
+            if(!(C == 1))
+            {
+                return false;
+            }
+            if(G % NumGroupsToMerge != 0)
+            {
+                return false;
+            }
+            if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
+            {
+                return false;
+            }
+        }

         // check vector access of A
         // FIXME: layout
@@ -928,11 +969,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                      is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
                      is_same_v<ALayout, ctc::NDHWGC>)
         {
-            const index_t C = arg.a_g_n_c_wis_lengths_[2];
-
+            // Check access per C
             if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
             {
-                return false;
+                // If not possible, check access per G
+                if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
+                     is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>() &&
+                     G % ABlockTransferSrcScalarPerVector == 0))
+                {
+                    return false;
+                }
             }
         }
         else
@@ -949,8 +995,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                      is_same_v<BLayout, ctc::KZYXGC>)
         {
-            const index_t C = arg.b_g_k_c_xs_lengths_[2];
-
             if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
             {
                 return false;
@@ -974,8 +1018,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                          is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
                          is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
             {
-                const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
-
                 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
                 {
                     valid = false;
@@ -1020,8 +1062,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                      is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
                      is_same_v<ELayout, ctc::NDHWGK>)
         {
-            const index_t K = arg.e_g_n_k_wos_lengths_[2];
-
             if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
             {
                 return false;
@@ -1172,7 +1212,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
             << BBlockTransferSrcScalarPerVector << ", "
             << CDEBlockTransferScalarPerVector_NPerBlock << ", "
             << CShuffleMXdlPerWavePerShuffle << ", "
-            << CShuffleNXdlPerWavePerShuffle
+            << CShuffleNXdlPerWavePerShuffle << ", "
+            << NumGroupsToMerge
             << ">";
         // clang-format on
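The merged-groups admissibility rule above can be restated as a standalone predicate; the helper below is a sketch (its name and the bool layout flag are illustrative, not the library's API):

// Sketch of the NumGroupsToMerge > 1 support check added above.
constexpr bool merged_groups_supported(int G, int C, int num_groups_to_merge,
                                       bool layout_is_nspatialgk_gkspatial_nspatialgc)
{
    if(num_groups_to_merge <= 1)
        return true;  // merging disabled, nothing to check
    if(C != 1)
        return false; // consecutive groups are contiguous only when C == 1
    if(G % num_groups_to_merge != 0)
        return false; // G must split evenly into merged slots
    return layout_is_nspatialgk_gkspatial_nspatialgc; // e.g. NHWGC/GKYXC/NHWGK
}

static_assert(merged_groups_supported(8, 1, 4, true));
static_assert(!merged_groups_supported(8, 2, 4, true));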
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp

Adds two layout predicates that fold the per-rank checks into a single call (+16 lines):

@@ -59,6 +59,22 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
            is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
 }

+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC()
+{
+    return is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
+           is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
+           is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>();
+}
+
+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC()
+{
+    return is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>() ||
+           is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>() ||
+           is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>();
+}
+
 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
 struct ComputePtrOffsetOfStridedBatch
 {
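A hypothetical compile-time use of the first predicate (the wrapper below is illustrative, not part of the header, and assumes the predicate is in scope):

// Illustrative only: gate a merged-groups code path on the new predicate.
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr int pick_num_groups_to_merge()
{
    // Merge groups only for the channels-last grouped layouts the predicate
    // recognizes (e.g. NHWGC/GKYXC/NHWGK); otherwise one group per slot.
    return is_NSpatialGK_GKSpatial_NSpatialGC<InLayout, WeiLayout, OutLayout>()
               ? 4  // assumed merge factor
               : 1;
}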
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

Adds a ConvScaleRelu elementwise operation next to the existing ConvScale (+25 lines):

@@ -1025,6 +1025,31 @@ struct ConvScale
     float scale_out_;
 };

+struct ConvScaleRelu
+{
+    __host__ __device__ ConvScaleRelu(float scale_in  = 1.f,
+                                      float scale_wei = 1.f,
+                                      float scale_out = 1.f)
+        : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
+    {
+    }
+
+    template <typename E, typename C>
+    __host__ __device__ void operator()(E& e, const C& c) const;
+
+    template <>
+    __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
+    {
+        float x;
+        Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
+        e = type_convert<f8_t>(x * scale_out_);
+    };
+
+    float scale_in_;
+    float scale_wei_;
+    float scale_out_;
+};
+
 // support fastconvert of int8 to fp16
 template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
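Per element, the operator computes relu(c * scale_in * scale_wei) * scale_out and then converts the result to f8. A host-side reference with plain float standing in for the f8_t output (a sketch, not CK code):

#include <algorithm>
#include <cstdio>

// Reference for the math ConvScaleRelu performs on each element.
float conv_scale_relu_ref(float c, float scale_in, float scale_wei, float scale_out)
{
    const float x = std::max(c * scale_in * scale_wei, 0.0f); // Relu after input/weight scaling
    return x * scale_out; // output scaling; CK then converts to f8_t
}

int main()
{
    // A negative accumulator value is clamped before output scaling.
    std::printf("%f\n", conv_scale_relu_ref(-3.0f, 0.5f, 0.25f, 2.0f)); // 0.000000
    std::printf("%f\n", conv_scale_relu_ref(4.0f, 0.5f, 0.25f, 2.0f));  // 1.000000
}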
include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp
View file @
a8efb3f0
...
@@ -27,7 +27,7 @@ template <index_t NDimSpatial,
...
@@ -27,7 +27,7 @@ template <index_t NDimSpatial,
index_t
NPerBlock
,
index_t
NPerBlock
,
index_t
GemmK1Number
,
index_t
GemmK1Number
,
index_t
K0PerBlock
,
index_t
K0PerBlock
,
index_t
Num
Batch
ToMerge
,
index_t
Num
Groups
ToMerge
,
device
::
ConvolutionBackwardWeightSpecialization
ConvBackwardWeightSpecialization
>
device
::
ConvolutionBackwardWeightSpecialization
ConvBackwardWeightSpecialization
>
struct
TransformConvBwdWeightToGemmV2
struct
TransformConvBwdWeightToGemmV2
{
{
...
@@ -45,7 +45,7 @@ struct TransformConvBwdWeightToGemmV2
...
@@ -45,7 +45,7 @@ struct TransformConvBwdWeightToGemmV2
const
index_t
BatchStride
=
output_strides
[
0
];
const
index_t
BatchStride
=
output_strides
[
0
];
const
index_t
WoStride
=
output_strides
[
4
];
const
index_t
WoStride
=
output_strides
[
4
];
const
auto
KStride
=
Number
<
1
>
{};
const
auto
KStride
=
Number
<
1
>
{};
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Ho
*
Wo
,
Num
Batch
ToMerge
,
K
),
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Ho
*
Wo
,
Num
Groups
ToMerge
,
K
),
make_tuple
(
WoStride
,
BatchStride
,
KStride
));
make_tuple
(
WoStride
,
BatchStride
,
KStride
));
}
}
...
@@ -65,13 +65,13 @@ struct TransformConvBwdWeightToGemmV2
...
@@ -65,13 +65,13 @@ struct TransformConvBwdWeightToGemmV2
if
constexpr
(
ConvBackwardWeightSpecialization
==
if
constexpr
(
ConvBackwardWeightSpecialization
==
device
::
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
device
::
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Hi
*
Wi
,
Num
Batch
ToMerge
,
C
),
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Hi
*
Wi
,
Num
Groups
ToMerge
,
C
),
make_tuple
(
WiStride
,
BatchStride
,
CStride
));
make_tuple
(
WiStride
,
BatchStride
,
CStride
));
}
}
else
else
{
{
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
Num
Batch
ToMerge
,
C
),
make_tuple
(
N
,
Hi
,
Wi
,
Num
Groups
ToMerge
,
C
),
make_tuple
(
NStride
,
HiStride
,
WiStride
,
BatchStride
,
CStride
));
make_tuple
(
NStride
,
HiStride
,
WiStride
,
BatchStride
,
CStride
));
}
}
}
}
...
@@ -88,30 +88,30 @@ struct TransformConvBwdWeightToGemmV2
...
@@ -88,30 +88,30 @@ struct TransformConvBwdWeightToGemmV2
 const auto KStride     = weights_strides[1];
 const auto XStride     = weights_strides[4];
 const auto BatchStride = weights_strides[0];
-// Add NumBatchToMerge for Batch+M dimension and 1 as a placeholder
+// Add NumGroupsToMerge for Batch+M dimension and 1 as a placeholder
 // for Batch+N dimension
 const auto desc = make_naive_tensor_descriptor(
-    make_tuple(NumBatchToMerge, K, Y * X, 1, C),
+    make_tuple(NumGroupsToMerge, K, Y * X, 1, C),
     make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
-// Pad 1 to NumBatchToMerge
+// Pad 1 to NumGroupsToMerge
 const auto padded_desc = transform_tensor_descriptor(
     desc,
-    make_tuple(make_pass_through_transform(NumBatchToMerge),
+    make_tuple(make_pass_through_transform(NumGroupsToMerge),
                make_pass_through_transform(K),
                make_pass_through_transform(Y * X),
-               make_pad_transform(1, 0, NumBatchToMerge - 1),
+               make_pad_transform(1, 0, NumGroupsToMerge - 1),
                make_pass_through_transform(C)),
     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
 // We need only the matrices on the diagonal. Xor returns 0 for equal
 // values, so a matrix that is not on the diagonal is stored in the padding.
 // To avoid a modulo after the xor we assume the number of groups to merge is a power of 2.
-static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 ||
-              NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 ||
-              NumBatchToMerge == 64);
+static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
+              NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
+              NumGroupsToMerge == 64);
 const auto unmerged_padded_desc = transform_tensor_descriptor(
     padded_desc,
-    make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)),
+    make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
                make_pass_through_transform(K),
                make_pass_through_transform(Y * X),
                make_pass_through_transform(C)),
...
@@ -120,8 +120,8 @@ struct TransformConvBwdWeightToGemmV2
 // Merge To M, N
 return transform_tensor_descriptor(
     unmerged_padded_desc,
-    make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)),
-               make_merge_transform(make_tuple(Y * X, NumBatchToMerge, C))),
+    make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)),
+               make_merge_transform(make_tuple(Y * X, NumGroupsToMerge, C))),
     make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
     make_tuple(Sequence<0>{}, Sequence<1>{}));
 }
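The xor trick in the hunk above is easy to verify in isolation: for a power-of-two extent N, i ^ j always stays inside [0, N) and equals 0 exactly when i == j, so only diagonal coordinates land on real data while everything off the diagonal falls into the pad region. A minimal standalone check (plain C++, illustrative only, not library code):

#include <cassert>

int main()
{
    constexpr int NumGroupsToMerge = 8; // must be a power of 2, as the static_assert enforces
    for(int i = 0; i < NumGroupsToMerge; ++i)
        for(int j = 0; j < NumGroupsToMerge; ++j)
        {
            const int mapped = i ^ j;
            assert(mapped >= 0 && mapped < NumGroupsToMerge); // stays in range, no modulo needed
            assert((mapped == 0) == (i == j));                // zero only on the diagonal
        }
    return 0;
}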
...
@@ -138,7 +138,7 @@ struct TransformConvBwdWeightToGemmV2
 const index_t BatchStride = output_strides[0];
 const index_t WoStride    = output_strides[5];
 const auto KStride        = Number<1>{};
 return make_naive_tensor_descriptor(
-    make_tuple(N * Do * Ho * Wo, NumBatchToMerge, K),
+    make_tuple(N * Do * Ho * Wo, NumGroupsToMerge, K),
     make_tuple(WoStride, BatchStride, KStride));
 }
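A naive tensor descriptor is just a strided address map. For the (N·Do·Ho·Wo, NumGroupsToMerge, K) view returned above, the element offset works out to (a sketch of the addressing implied by the lengths/strides, not library code):

$$\text{offset}(p,\, g,\, k) = p \cdot \text{WoStride} + g \cdot \text{BatchStride} + k \cdot \text{KStride}$$

so stepping along the merged-group dimension jumps by a whole group's BatchStride, while K stays contiguous (KStride = 1).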
...
@@ -160,13 +160,13 @@ struct TransformConvBwdWeightToGemmV2
 if constexpr(ConvBackwardWeightSpecialization ==
              device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
 {
     return make_naive_tensor_descriptor(
-        make_tuple(N * Di * Hi * Wi, NumBatchToMerge, C),
+        make_tuple(N * Di * Hi * Wi, NumGroupsToMerge, C),
         make_tuple(WiStride, BatchStride, CStride));
 }
 else
 {
     return make_naive_tensor_descriptor(
-        make_tuple(N, Di, Hi, Wi, NumBatchToMerge, C),
+        make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C),
         make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride));
 }
 }
...
@@ -184,29 +184,29 @@ struct TransformConvBwdWeightToGemmV2
 const auto KStride     = weights_strides[1];
 const auto XStride     = weights_strides[5];
 const auto BatchStride = weights_strides[0];
-// Add NumBatchToMerge for Batch+M dimension and 1 as a placeholder for Batch+N dimension
+// Add NumGroupsToMerge for Batch+M dimension and 1 as a placeholder for Batch+N dimension
 const auto desc = make_naive_tensor_descriptor(
-    make_tuple(NumBatchToMerge, K, Z * Y * X, 1, C),
+    make_tuple(NumGroupsToMerge, K, Z * Y * X, 1, C),
     make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
-// Pad 1 to NumBatchToMerge
+// Pad 1 to NumGroupsToMerge
 const auto padded_desc = transform_tensor_descriptor(
     desc,
-    make_tuple(make_pass_through_transform(NumBatchToMerge),
+    make_tuple(make_pass_through_transform(NumGroupsToMerge),
                make_pass_through_transform(K),
                make_pass_through_transform(Z * Y * X),
-               make_pad_transform(1, 0, NumBatchToMerge - 1),
+               make_pad_transform(1, 0, NumGroupsToMerge - 1),
                make_pass_through_transform(C)),
     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
 // We need only the matrices on the diagonal. Xor returns 0 for equal
 // values, so a matrix that is not on the diagonal is stored in the padding.
 // To avoid a modulo after the xor we assume the number of groups to merge is a power of 2.
-static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 ||
-              NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 ||
-              NumBatchToMerge == 64);
+static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
+              NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
+              NumGroupsToMerge == 64);
 const auto unmerged_padded_desc = transform_tensor_descriptor(
     padded_desc,
-    make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)),
+    make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
                make_pass_through_transform(K),
                make_pass_through_transform(Z * Y * X),
                make_pass_through_transform(C)),
...
@@ -215,8 +215,8 @@ struct TransformConvBwdWeightToGemmV2
 // Merge To M, N
 return transform_tensor_descriptor(
     unmerged_padded_desc,
-    make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)),
-               make_merge_transform(make_tuple(Z * Y * X, NumBatchToMerge, C))),
+    make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)),
+               make_merge_transform(make_tuple(Z * Y * X, NumGroupsToMerge, C))),
     make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
     make_tuple(Sequence<0>{}, Sequence<1>{}));
 }
...
@@ -262,8 +262,8 @@ struct TransformConvBwdWeightToGemmV2
 const index_t InRightPadW = input_right_pads[1];
 const index_t GemmKTotal  = N * Ho * Wo;
-const index_t GemmM       = K * NumBatchToMerge;
-const index_t GemmN       = C * X * Y * NumBatchToMerge;
+const index_t GemmM       = K * NumGroupsToMerge;
+const index_t GemmN       = C * X * Y * NumGroupsToMerge;
 const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
 const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
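To make the GEMM sizing concrete, here is a worked example with hypothetical values (K = 48, C = 16, Y = X = 3, NumGroupsToMerge = 4, MPerBlock = NPerBlock = 128):

$$\text{GemmM} = K \cdot \text{NumGroupsToMerge} = 48 \cdot 4 = 192, \qquad \text{PadGemmM} = 128 - (192 \bmod 128) = 64$$
$$\text{GemmN} = C \cdot X \cdot Y \cdot \text{NumGroupsToMerge} = 16 \cdot 3 \cdot 3 \cdot 4 = 576, \qquad \text{PadGemmN} = 128 - (576 \bmod 128) = 64$$

so the padded extents become 256 and 640, both multiples of the block tile.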
...
@@ -286,7 +286,7 @@ struct TransformConvBwdWeightToGemmV2
     out_grid_desc,
     make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-               make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
+               make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
     make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
     make_tuple(Sequence<0>{}, Sequence<1>{}));
...
@@ -302,7 +302,7 @@ struct TransformConvBwdWeightToGemmV2
     in_grid_desc,
     make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-               make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))),
+               make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
     make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
     make_tuple(Sequence<0>{}, Sequence<1>{}));
...
@@ -324,7 +324,7 @@ struct TransformConvBwdWeightToGemmV2
     out_grid_desc,
     make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-               make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
+               make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
     make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
     make_tuple(Sequence<0>{}, Sequence<1>{}));
...
@@ -341,7 +341,7 @@ struct TransformConvBwdWeightToGemmV2
 make_tuple(make_pass_through_transform(N),
            make_pad_transform(Hi, InLeftPadH, InRightPadH),
            make_pad_transform(Wi, InLeftPadW, InRightPadW),
-           make_pass_through_transform(NumBatchToMerge),
+           make_pass_through_transform(NumGroupsToMerge),
            make_pass_through_transform(C)),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
...
@@ -354,7 +354,7 @@ struct TransformConvBwdWeightToGemmV2
            make_pass_through_transform(N),
            make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
            make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
-           make_pass_through_transform(NumBatchToMerge),
+           make_pass_through_transform(NumGroupsToMerge),
            make_pass_through_transform(C)),
 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
...
@@ -366,7 +366,7 @@ struct TransformConvBwdWeightToGemmV2
 const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
     in_n_y_ho_x_wo_c_grid_desc,
-    make_tuple(make_merge_transform(make_tuple(Y, X, NumBatchToMerge, C)),
+    make_tuple(make_merge_transform(make_tuple(Y, X, NumGroupsToMerge, C)),
                make_merge_transform(make_tuple(N, Ho, Wo))),
     make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}),
     make_tuple(Sequence<1>{}, Sequence<0>{}));
...
@@ -465,8 +465,8 @@ struct TransformConvBwdWeightToGemmV2
 const index_t InRightPadW = input_right_pads[2];
 const index_t GemmKTotal  = N * Do * Ho * Wo;
-const index_t GemmM       = K * NumBatchToMerge;
-const index_t GemmN       = C * Z * X * Y * NumBatchToMerge;
+const index_t GemmM       = K * NumGroupsToMerge;
+const index_t GemmN       = C * Z * X * Y * NumGroupsToMerge;
 const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
 const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
...
@@ -489,7 +489,7 @@ struct TransformConvBwdWeightToGemmV2
     out_grid_desc,
     make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-               make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
+               make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
     make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
     make_tuple(Sequence<0>{}, Sequence<1>{}));
...
@@ -505,7 +505,7 @@ struct TransformConvBwdWeightToGemmV2
     in_grid_desc,
     make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-               make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))),
+               make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
     make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
     make_tuple(Sequence<0>{}, Sequence<1>{}));
...
@@ -527,7 +527,7 @@ struct TransformConvBwdWeightToGemmV2
     out_grid_desc,
     make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-               make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
+               make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
     make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
     make_tuple(Sequence<0>{}, Sequence<1>{}));
...
@@ -545,7 +545,7 @@ struct TransformConvBwdWeightToGemmV2
            make_pad_transform(Di, InLeftPadD, InRightPadD),
            make_pad_transform(Hi, InLeftPadH, InRightPadH),
            make_pad_transform(Wi, InLeftPadW, InRightPadW),
-           make_pass_through_transform(NumBatchToMerge),
+           make_pass_through_transform(NumGroupsToMerge),
            make_pass_through_transform(C)),
 make_tuple(Sequence<0>{}, Sequence<1>{},
...
@@ -567,7 +567,7 @@ struct TransformConvBwdWeightToGemmV2
            make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
            make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
            make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
-           make_pass_through_transform(NumBatchToMerge),
+           make_pass_through_transform(NumGroupsToMerge),
            make_pass_through_transform(C)),
 make_tuple(Sequence<0>{}, Sequence<1>{},
...
@@ -584,7 +584,7 @@ struct TransformConvBwdWeightToGemmV2
 const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
     in_n_z_do_y_ho_x_wo_c_grid_desc,
-    make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumBatchToMerge, C)),
+    make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumGroupsToMerge, C)),
                make_merge_transform(make_tuple(N, Do, Ho, Wo))),
     make_tuple(Sequence<1, 3, 5, 7, 8>{}, Sequence<0, 2, 4, 6>{}),
     make_tuple(Sequence<1>{}, Sequence<0>{}));
...

include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
This diff is collapsed.
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
...
@@ -40,10 +40,10 @@ template <ck::index_t NDimSpatial,
           BlockGemmPipelineVersion PipelineVersion>
 using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple<
     // clang-format off
-        //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumBatch|
+        //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
         //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
         //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
         //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
         DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
         DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>,
         DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>,
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F16  = ck::half_t;
using F32  = float;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Empty_Tuple = ck::Tuple<>;

using namespace ck::tensor_layout::convolution;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto ConvFwd3x3 = ConvolutionForwardSpecialization::Filter3x3;

static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple<
    // clang-format off
    //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups|
    //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge|
    //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| |
    //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    // Instances with NumGroupsPerBatch > 1
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>
    // clang-format on
    >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple<
    // clang-format off
    //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
    //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
    //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
    //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    // Instances with NumGroupsPerBatch > 1
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>
    // clang-format on
    >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
    // clang-format off
    //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
    //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
    //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
    //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    // Instances with NumGroupsPerBatch > 1
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32>
    // clang-format on
    >;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
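As a usage sketch, the new lists instantiate like any other instance tuple in this library. The alias name below is hypothetical; the template arguments (layouts, Empty_Tuple, ConvFwdDefault) are the ones defined in the file above:

// Hypothetical instantiation of the merged-groups f16 list for a 2D
// NHWGC/GKYXC/NHWGK convolution; each tuple element is a device op compiled
// with NumGroupsToMerge = 8, 16 or 32 respectively.
using Conv2dMergedGroupsF16 =
    device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
                                                            NHWGC,
                                                            GKYXC,
                                                            Empty_Tuple,
                                                            NHWGK,
                                                            ConvFwdDefault>;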
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp
...
@@ -147,6 +147,43 @@ using device_grouped_conv_fwd_xdl_outelementop_f8_bf8_instances = std::tuple<
     // clang-format on
     >;
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec,
          typename OutElementOp>
using device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances = std::tuple<
    // clang-format off
    //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute|
    //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB|
    //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
    //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)
    // generic instance
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>,
    // instances for small conv.K and conv.C
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF8, F8, F32, F32, Tuple<>, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>
#endif
    // clang-format on
    >;
} // namespace instance
} // namespace device
} // namespace tensor_operation
...

library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
...
@@ -17,6 +17,7 @@
 #endif
 #ifdef CK_USE_XDL
 #include "grouped_convolution_forward_xdl.inc"
+#include "grouped_convolution_forward_xdl_merged_groups.inc"
 #include "grouped_convolution_forward_comp_xdl.inc"
 #include "grouped_convolution_forward_mem_inter_xdl.inc"
 #include "grouped_convolution_forward_mem_intra_xdl.inc"
...
@@ -199,6 +200,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<BComputeType, float>)
         {
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
+            add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
                 op_ptrs);
...
@@ -212,6 +215,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<BComputeType, half_t>)
         {
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
+            add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
                 op_ptrs);
...
@@ -227,6 +232,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<BComputeType, ck::bhalf_t>)
         {
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
+            add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
                 op_ptrs);
...
@@ -284,6 +291,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<BComputeType, float>)
         {
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
+            add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs);
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
                 op_ptrs);
...
@@ -338,6 +347,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<BComputeType, half_t>)
         {
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
+            add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs);
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
                 op_ptrs);
...
@@ -353,6 +364,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<BComputeType, ck::bhalf_t>)
         {
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
+            add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs);
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
                 op_ptrs);
...
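Because the merged-groups lists are registered through the same factory, existing callers pick them up automatically. A minimal sketch of enumerating matching instances (the DeviceOp signature is copied from the declarations elsewhere in this commit; GetInstances is the factory entry point shown in the convscale-relu factory below):

// Sketch: enumerate all registered 2D NHWGC f16 forward instances. After this
// change the returned vector also contains the merged-groups
// (NumGroupsToMerge > 1) device ops registered above.
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    2, NHWGC, GKYXC, ck::Tuple<>, NHWGK, F16, F16, ck::Tuple<>, F16,
    PassThrough, PassThrough, PassThrough>;

const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();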
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp
...
@@ -70,6 +70,22 @@ void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_ins
                                 ConvScale,
                                 F8,
                                 BF8>>>& instances);
+void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
+                                                                NDHWGC,
+                                                                GKZYXC,
+                                                                ck::Tuple<>,
+                                                                NDHWGK,
+                                                                BF8,
+                                                                F8,
+                                                                ck::Tuple<>,
+                                                                F8,
+                                                                PassThrough,
+                                                                PassThrough,
+                                                                ConvScale,
+                                                                BF8,
+                                                                F8>>>& instances);
 #endif

 template <ck::index_t NumDimSpatial,
...
@@ -147,6 +163,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
             add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(
                 op_ptrs);
         }
+        if constexpr(is_same_v<InDataType, bf8_t> && is_same_v<WeiDataType, f8_t> &&
+                     is_same_v<OutDataType, f8_t> && is_same_v<AComputeType, bf8_t> &&
+                     is_same_v<BComputeType, f8_t>)
+        {
+            add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances(
+                op_ptrs);
+        }
 #endif
     }

     return op_ptrs;
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using PassThrough   = ck::tensor_operation::element_wise::PassThrough;
using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu;

#ifdef CK_ENABLE_FP8
void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                ck::Tuple<>,
                                                                NDHWGK,
                                                                F8,
                                                                F8,
                                                                ck::Tuple<>,
                                                                F8,
                                                                PassThrough,
                                                                PassThrough,
                                                                ConvScaleRelu,
                                                                F8,
                                                                F8>>>& instances);
#endif

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename DLayouts,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename DDataTypes,
          typename OutDataType,
          typename AComputeType,
          typename BComputeType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                                                  InLayout,
                                                                  WeiLayout,
                                                                  DLayouts,
                                                                  OutLayout,
                                                                  InDataType,
                                                                  WeiDataType,
                                                                  DDataTypes,
                                                                  OutDataType,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  ConvScaleRelu,
                                                                  AComputeType,
                                                                  BComputeType>>
{
    using DeviceOp = DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                                     InLayout,
                                                     WeiLayout,
                                                     DLayouts,
                                                     OutLayout,
                                                     InDataType,
                                                     WeiDataType,
                                                     DDataTypes,
                                                     OutDataType,
                                                     PassThrough,
                                                     PassThrough,
                                                     ConvScaleRelu,
                                                     AComputeType,
                                                     BComputeType>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
                     is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
        {
#ifdef CK_ENABLE_FP8
            if constexpr(is_same_v<InDataType, f8_t> && is_same_v<WeiDataType, f8_t> &&
                         is_same_v<OutDataType, f8_t> && is_same_v<AComputeType, f8_t> &&
                         is_same_v<BComputeType, f8_t>)
            {
                add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances(
                    op_ptrs);
            }
#endif
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
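ConvScaleRelu itself is defined in unary_element_wise_operation.hpp, which this page does not expand. As a rough sketch only, assuming the op composes ConvScale-style scaling with a ReLU clamp (the functor below is illustrative and is not the library's actual definition):

// Hypothetical stand-in for ConvScaleRelu: scale the f32 accumulator by the
// input/weight/output scales, then clamp negatives to zero before the f8 cast.
struct ConvScaleReluSketch
{
    float scale_in;
    float scale_wei;
    float scale_out;

    template <typename E, typename C>
    void operator()(E& e, const C& c) const
    {
        const float v = static_cast<float>(c) * scale_in * scale_wei * scale_out;
        e = static_cast<E>(v > 0.f ? v : 0.f); // ReLU applied after scaling
    }
};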
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                NHWGC,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NHWGK,
                                                                BF16,
                                                                BF16,
                                                                Empty_Tuple,
                                                                BF16,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                NHWGC,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NHWGK,
                                                                F16,
                                                                F16,
                                                                Empty_Tuple,
                                                                F16,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                NHWGC,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NHWGK,
                                                                F32,
                                                                F32,
                                                                Empty_Tuple,
                                                                F32,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                Empty_Tuple,
                                                                NDHWGK,
                                                                BF16,
                                                                BF16,
                                                                Empty_Tuple,
                                                                BF16,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                Empty_Tuple,
                                                                NDHWGK,
                                                                F16,
                                                                F16,
                                                                Empty_Tuple,
                                                                F16,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                Empty_Tuple,
                                                                NDHWGK,
                                                                F32,
                                                                F32,
                                                                Empty_Tuple,
                                                                F32,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances);
#endif

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
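These declarations are consumed by the grouped-convolution-forward factory headers. A short sketch of the intended call pattern (hypothetical caller; assumes CK_ENABLE_FP16 and the layout/type aliases used above are in scope):

// Hypothetical sketch: populate a vector with the merged-groups f16 conv2d kernels.
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                            NHWGC,
                                                            GKYXC,
                                                            Empty_Tuple,
                                                            NHWGK,
                                                            F16,
                                                            F16,
                                                            Empty_Tuple,
                                                            F16,
                                                            PassThrough,
                                                            PassThrough,
                                                            PassThrough>>>
    op_ptrs;

// Appends every merged-groups tuning configuration compiled for this layout/type combination.
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);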
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp
...
@@ -43,7 +43,8 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple<
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
-        DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
+        // Disable due to test failure
+        // DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
...
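Each entry above is one concrete tuning configuration in the comp_instances tuple. As a reading aid, the leading arguments decode roughly as follows (a sketch; field names follow the usual DeviceGemm_Xdl_CShuffleV3 parameter order, which this diff does not restate):

// DeviceGemm_Xdl_CShuffleV3<
//     Row, Row, Row,                          // A, B, C layouts
//     F16, F8, F16,                           // A, B, C data types (mixed f16 x f8 -> f16)
//     F32, F16,                               // accumulator and CShuffle types
//     PassThrough, PassThrough, PassThrough,  // A, B, C elementwise ops
//     GemmSpec,                               // padding/specialization policy
//     256,                                    // threads per block
//     256, 256, 32,                           // M, N, K tile per block
//     ...,                                    // vector widths and thread-cluster mappings
//     BlockGemmPipelineScheduler::Intrawave,  // pipeline scheduling
//     BlockGemmPipelineVersion::v4>           // pipeline version (v3/v4/v5 variants above)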
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
...
@@ -9,6 +9,11 @@ add_instance_library(device_grouped_conv2d_fwd_instance
     xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
     xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
     xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
+    # merged groups
+    # NHWGC, GKYXC, NHWGK
+    xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
+    xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp
+    xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
     #mem
     # NHWGC, GKYXC, NHWGK
     xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                NHWGC,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NHWGK,
                                                                BF16,
                                                                BF16,
                                                                Empty_Tuple,
                                                                BF16,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2,
                                                                 NHWGC,
                                                                 GKYXC,
                                                                 Empty_Tuple,
                                                                 NHWGK,
                                                                 ConvFwdDefault>{});
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2,
                                                                 NHWGC,
                                                                 GKYXC,
                                                                 Empty_Tuple,
                                                                 NHWGK,
                                                                 ConvFwd3x3>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
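Note that each of these registration functions appends two tuning sets: the generic ConvFwdDefault instances and a ConvFwd3x3 set specialized for 3x3 filters. In CK's usual dispatch flow the caller still filters candidates through each instance's IsSupportedArgument() check, so the specialized set is only selected when the filter shape actually matches.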
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                NHWGC,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NHWGK,
                                                                F16,
                                                                F16,
                                                                Empty_Tuple,
                                                                F16,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
                                                                NHWGC,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NHWGK,
                                                                ConvFwdDefault>{});
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
                                                                NHWGC,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NHWGK,
                                                                ConvFwd3x3>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
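The layout comment in these files, in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k], fixes all tensor shapes except the output spatial sizes. For reference, a small standalone helper (illustration only, not CK API) computing one spatial output length with the standard convolution shape formula:

#include <cstdint>

// Output length of one spatial dimension for forward convolution.
std::int64_t conv_out_length(std::int64_t in_len,   // Hi or Wi
                             std::int64_t filter,   // Y or X
                             std::int64_t stride,
                             std::int64_t dilation,
                             std::int64_t pad_left,
                             std::int64_t pad_right)
{
    const std::int64_t eff_filter = (filter - 1) * dilation + 1; // dilated footprint
    return (in_len + pad_left + pad_right - eff_filter) / stride + 1;
}

// Example: Hi = 28, 3x3 filter, stride 1, dilation 1, pad 1 on both sides -> Ho = 28.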