gaoqiong / composable_kernel_ROCM / Commits / e4112de7

Commit e4112de7, authored May 22, 2024 by Jun Liu
Merge branch 'develop' into amd-develop
Parents: a6ef5c39, fd72380a
Changes: 39

Showing 19 changed files with 926 additions and 116 deletions (+926 -116)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp (+11 -13)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp (+2 -3)
include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp (+640 -0, new file)
include/ck_tile/core/config.hpp (+21 -9)
include/ck_tile/core/numeric/float8.hpp (+10 -10)
include/ck_tile/core/tensor/tile_elementwise.hpp (+1 -1)
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp (+10 -14)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp (+16 -6)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp (+6 -2)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc (+26 -2)
library/src/tensor_operation_instance/gpu/CMakeLists.txt (+42 -8)
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt (+3 -1)
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp (+4 -11)
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp (+41 -0)
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt (+10 -8)
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp (+4 -11)
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp (+41 -0)
test/CMakeLists.txt (+32 -4)
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp (+6 -13)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -33,8 +33,7 @@ __global__ void
        // __attribute__((amdgpu_waves_per_eu(1, 1)))
        kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
...
@@ -49,7 +48,7 @@ __global__ void
                                  karg.c_element_op);
#else
    ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
}

template <typename GridwiseGemm,
...
@@ -64,8 +63,7 @@ __global__ void
        // __attribute__((amdgpu_waves_per_eu(1, 1)))
        kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // Passing two lds pointers is the key to tell the compiler that ds_read/write
    // operate on different lds chunks at the same time without an order dependency
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...
@@ -84,7 +82,7 @@ __global__ void
                                  karg.c_element_op);
#else
    ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
}

template <typename ALayout,
...
@@ -783,8 +781,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
        constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
            a_lds_block_desc,
-            make_tuple(make_xor_transform(make_tuple(Number<MPerBlock / MLdsLayer>{},
-                                                     Number<AK0Number * MLdsLayer>{})),
+            make_tuple(make_xor_with_modulo_transform(make_tuple(Number<MPerBlock / MLdsLayer>{},
+                                                                 Number<AK0Number * MLdsLayer>{})),
                       make_pass_through_transform(AK1Number)),
            make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
            make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
...
@@ -849,7 +847,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
            make_tuple(
                make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                make_pass_through_transform(Number<K0PerThreadWrite>{}),
-                make_xor_transform(
+                make_xor_with_modulo_transform(
                    make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
                make_pass_through_transform(Number<mpair>{}),
                make_pass_through_transform(AK1Number)),
...
@@ -920,8 +918,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
        constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
            b_lds_block_desc,
-            make_tuple(make_xor_transform(make_tuple(Number<NPerBlock / NLdsLayer>{},
-                                                     Number<BK0Number * NLdsLayer>{})),
+            make_tuple(make_xor_with_modulo_transform(make_tuple(Number<NPerBlock / NLdsLayer>{},
+                                                                 Number<BK0Number * NLdsLayer>{})),
                       make_pass_through_transform(BK1Number)),
            make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
            make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
...
@@ -983,7 +981,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
            make_tuple(
                make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                make_pass_through_transform(Number<K0PerThreadWrite>{}),
-                make_xor_transform(
+                make_xor_with_modulo_transform(
                    make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
                make_pass_through_transform(Number<npair>{}),
                make_pass_through_transform(BK1Number)),
...
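The guard these hunks collapse follows one pattern: the kernel body is kept for the host compilation pass (where __HIP_DEVICE_COMPILE__ is not defined) and for gfx9 device passes, and compiled out for every other device target. A minimal sketch of that pattern, using a hypothetical kernel and argument type that are not part of this commit:

    // Hypothetical kernel illustrating the guard pattern used in the hunks above.
    template <typename Arg>
    __global__ void example_kernel(Arg karg)
    {
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
        // real work using karg goes here
    #else
        (void)karg; // the CK sources use "ignore = karg;" to silence unused-argument warnings
    #endif
    }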
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp
@@ -38,8 +38,7 @@ __global__ void
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
    __shared__ uint8_t p_shared[shared_size];
...
@@ -52,7 +51,7 @@ __global__ void
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
}

template <index_t BlockSize,
...
include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp
new file (mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/library/utility/numeric.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"

namespace ck {
namespace tensor_operation {

/**
 * @brief Transform conv bwd weight to gemm v2
 *
 * This version does the following things:
 * 1. Merges KBatch with K0 to align the descriptor with universal gemm.
 * 2. Merges Batch with the M and N dimensions. This increases the amount of compute in
 *    the case of small M and N, and also enables vector load and store in the case of
 *    K = 1, C = 1 and NHWGC layout.
 */
template <index_t NDimSpatial,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t GemmK1Number,
          index_t K0PerBlock,
          index_t NumBatchToMerge,
          device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
struct TransformConvBwdWeightToGemmV2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
    constexpr static auto make_out_grid_desc(const index_t N,
                                             const index_t Ho,
                                             const index_t Wo,
                                             const index_t K,
                                             const std::array<index_t, NDimSpatial + 3>& output_strides)
    {
        const index_t BatchStride = output_strides[0];
        const index_t WoStride    = output_strides[4];
        const auto KStride        = Number<1>{};
        return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumBatchToMerge, K),
                                            make_tuple(WoStride, BatchStride, KStride));
    }

    template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
    constexpr static auto make_in_grid_desc(const index_t N,
                                            const index_t Hi,
                                            const index_t Wi,
                                            const index_t C,
                                            const std::array<index_t, NDimSpatial + 3>& input_strides)
    {
        const index_t BatchStride = input_strides[0];
        const index_t NStride     = input_strides[1];
        const index_t HiStride    = input_strides[3];
        const index_t WiStride    = input_strides[4];
        const auto CStride        = input_strides[2];

        if constexpr(ConvBackwardWeightSpecialization ==
                     device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
            return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumBatchToMerge, C),
                                                make_tuple(WiStride, BatchStride, CStride));
        }
        else
        {
            return make_naive_tensor_descriptor(
                make_tuple(N, Hi, Wi, NumBatchToMerge, C),
                make_tuple(NStride, HiStride, WiStride, BatchStride, CStride));
        }
    }

    template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
    constexpr static auto make_wei_grid_desc(const index_t K,
                                             const index_t Y,
                                             const index_t X,
                                             const index_t C,
                                             const std::array<index_t, NDimSpatial + 3>& weights_strides)
    {
        const auto CStride     = Number<1>{};
        const auto KStride     = weights_strides[1];
        const auto XStride     = weights_strides[4];
        const auto BatchStride = weights_strides[0];
        // Add NumBatchToMerge for the Batch+M dimension and 1 as a placeholder
        // for the Batch+N dimension
        const auto desc = make_naive_tensor_descriptor(
            make_tuple(NumBatchToMerge, K, Y * X, 1, C),
            make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
        // Pad 1 to NumBatchToMerge
        const auto padded_desc = transform_tensor_descriptor(
            desc,
            make_tuple(make_pass_through_transform(NumBatchToMerge),
                       make_pass_through_transform(K),
                       make_pass_through_transform(Y * X),
                       make_pad_transform(1, 0, NumBatchToMerge - 1),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
        // We need only the matrices from the diagonal. Xor returns 0 for equal values,
        // so a matrix that is not on the diagonal is stored in the padding.
        // To avoid a modulo after the xor we require NumBatchToMerge to be a power of 2.
        static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 ||
                      NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 ||
                      NumBatchToMerge == 64);
        const auto unmerged_padded_desc = transform_tensor_descriptor(
            padded_desc,
            make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)),
                       make_pass_through_transform(K),
                       make_pass_through_transform(Y * X),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}),
            make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}));
        // Merge to M, N
        return transform_tensor_descriptor(
            unmerged_padded_desc,
            make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)),
                       make_merge_transform(make_tuple(Y * X, NumBatchToMerge, C))),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
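    // For example, with NumBatchToMerge = 4: the batch coordinate merged into M and the one
    // merged into N form a logical 4 x 4 grid after padding, but only one slot per row is
    // backed by real memory. The xor of the two batch coordinates is 0 exactly when they are
    // equal, so only the "diagonal" combinations (0,0), (1,1), (2,2), (3,3) resolve to stored
    // weight data, while every mixed combination lands in the padded region. Requiring a
    // power-of-two NumBatchToMerge keeps the xor result inside that range without a modulo.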
    template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
    constexpr static auto make_out_grid_desc(const index_t N,
                                             const index_t Do,
                                             const index_t Ho,
                                             const index_t Wo,
                                             const index_t K,
                                             const std::array<index_t, NDimSpatial + 3>& output_strides)
    {
        const index_t BatchStride = output_strides[0];
        const index_t WoStride    = output_strides[5];
        const auto KStride        = Number<1>{};
        return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumBatchToMerge, K),
                                            make_tuple(WoStride, BatchStride, KStride));
    }

    template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
    constexpr static auto make_in_grid_desc(const index_t N,
                                            const index_t Di,
                                            const index_t Hi,
                                            const index_t Wi,
                                            const index_t C,
                                            const std::array<index_t, NDimSpatial + 3>& input_strides)
    {
        const index_t BatchStride = input_strides[0];
        const index_t NStride     = input_strides[1];
        const index_t DiStride    = input_strides[3];
        const index_t HiStride    = input_strides[4];
        const index_t WiStride    = input_strides[5];
        const auto CStride        = input_strides[2];

        if constexpr(ConvBackwardWeightSpecialization ==
                     device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
            return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumBatchToMerge, C),
                                                make_tuple(WiStride, BatchStride, CStride));
        }
        else
        {
            return make_naive_tensor_descriptor(
                make_tuple(N, Di, Hi, Wi, NumBatchToMerge, C),
                make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride));
        }
    }

    template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
    constexpr static auto make_wei_grid_desc(const index_t K,
                                             const index_t Z,
                                             const index_t Y,
                                             const index_t X,
                                             const index_t C,
                                             const std::array<index_t, NDimSpatial + 3>& weights_strides)
    {
        const auto CStride     = Number<1>{};
        const auto KStride     = weights_strides[1];
        const auto XStride     = weights_strides[5];
        const auto BatchStride = weights_strides[0];
        // Add NumBatchToMerge for the Batch+M dimension and 1 as a placeholder for the
        // Batch+N dimension
        const auto desc = make_naive_tensor_descriptor(
            make_tuple(NumBatchToMerge, K, Z * Y * X, 1, C),
            make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
        // Pad 1 to NumBatchToMerge
        const auto padded_desc = transform_tensor_descriptor(
            desc,
            make_tuple(make_pass_through_transform(NumBatchToMerge),
                       make_pass_through_transform(K),
                       make_pass_through_transform(Z * Y * X),
                       make_pad_transform(1, 0, NumBatchToMerge - 1),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
        // We need only the matrices from the diagonal. Xor returns 0 for equal values,
        // so a matrix that is not on the diagonal is stored in the padding.
        // To avoid a modulo after the xor we require NumBatchToMerge to be a power of 2.
        static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 ||
                      NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 ||
                      NumBatchToMerge == 64);
        const auto unmerged_padded_desc = transform_tensor_descriptor(
            padded_desc,
            make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)),
                       make_pass_through_transform(K),
                       make_pass_through_transform(Z * Y * X),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}),
            make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}));
        // Merge to M, N
        return transform_tensor_descriptor(
            unmerged_padded_desc,
            make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)),
                       make_merge_transform(make_tuple(Z * Y * X, NumBatchToMerge, C))),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
        const index_t N,
        const index_t K,
        const index_t C,
        const std::array<index_t, NDimSpatial>& input_spatial_lengths,
        const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
        const std::array<index_t, NDimSpatial>& output_spatial_lengths,
        const std::array<index_t, NDimSpatial + 3>& input_strides,
        const std::array<index_t, NDimSpatial + 3>& weights_strides,
        const std::array<index_t, NDimSpatial + 3>& output_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
        const std::array<index_t, NDimSpatial>& input_right_pads,
        const index_t batch_k)
    {
        using namespace ck;

        const index_t Hi = input_spatial_lengths[0];
        const index_t Wi = input_spatial_lengths[1];

        const index_t Ho = output_spatial_lengths[0];
        const index_t Wo = output_spatial_lengths[1];

        const index_t Y = filter_spatial_lengths[0];
        const index_t X = filter_spatial_lengths[1];

        const index_t ConvStrideH = conv_filter_strides[0];
        const index_t ConvStrideW = conv_filter_strides[1];

        const index_t ConvDilationH = conv_filter_dilations[0];
        const index_t ConvDilationW = conv_filter_dilations[1];

        const index_t InLeftPadH = input_left_pads[0];
        const index_t InLeftPadW = input_left_pads[1];

        const index_t InRightPadH = input_right_pads[0];
        const index_t InRightPadW = input_right_pads[1];

        const index_t GemmKTotal = N * Ho * Wo;
        const index_t GemmM      = K * NumBatchToMerge;
        const index_t GemmN      = C * X * Y * NumBatchToMerge;

        const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
        const auto PadGemmN = NPerBlock - GemmN % NPerBlock;

        const index_t GemmKBatch = batch_k;
        const index_t GemmK0 =
            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
            K0PerBlock;
        const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;

        const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
        const auto in_grid_desc  = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
        const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);

        if constexpr(ConvBackwardWeightSpecialization ==
                     device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
            // A: output tensor
            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
                out_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
                out_gemmkpad_gemmm_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmM)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            // B: input tensor
            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
                in_gemmkpad_gemmn_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                              wei_grid_desc);
        }
        else
        {
            // A: output tensor
            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
                out_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
                out_gemmkpad_gemmm_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmM)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            // B: input tensor
            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(NumBatchToMerge),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

            const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                           make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                           make_pass_through_transform(NumBatchToMerge),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}, Sequence<6>{}));

            const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
                in_n_y_ho_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(Y, X, NumBatchToMerge, C)),
                           make_merge_transform(make_tuple(N, Ho, Wo))),
                make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}),
                make_tuple(Sequence<1>{}, Sequence<0>{}));

            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
                in_gemmktotal_gemmn_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
                in_gemmkpad_gemmn_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            // Pad
            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = transform_tensor_descriptor(
                out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
                           make_right_pad_transform(GemmM, PadGemmM),
                           make_pass_through_transform(GemmK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = transform_tensor_descriptor(
                in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
                           make_right_pad_transform(GemmN, PadGemmN),
                           make_pass_through_transform(GemmK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto wei_gemmm_gemmn_pad_grid_desc = transform_tensor_descriptor(
                wei_grid_desc,
                make_tuple(make_right_pad_transform(GemmM, PadGemmM),
                           make_right_pad_transform(GemmN, PadGemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
                              wei_gemmm_gemmn_pad_grid_desc);
        }
    }
    template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
        const index_t N,
        const index_t K,
        const index_t C,
        const std::array<index_t, NDimSpatial>& input_spatial_lengths,
        const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
        const std::array<index_t, NDimSpatial>& output_spatial_lengths,
        const std::array<index_t, NDimSpatial + 3>& input_strides,
        const std::array<index_t, NDimSpatial + 3>& weights_strides,
        const std::array<index_t, NDimSpatial + 3>& output_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
        const std::array<index_t, NDimSpatial>& input_right_pads,
        const index_t batch_k)
    {
        using namespace ck;

        const index_t Di = input_spatial_lengths[0];
        const index_t Hi = input_spatial_lengths[1];
        const index_t Wi = input_spatial_lengths[2];

        const index_t Do = output_spatial_lengths[0];
        const index_t Ho = output_spatial_lengths[1];
        const index_t Wo = output_spatial_lengths[2];

        const index_t Z = filter_spatial_lengths[0];
        const index_t Y = filter_spatial_lengths[1];
        const index_t X = filter_spatial_lengths[2];

        const index_t ConvStrideD = conv_filter_strides[0];
        const index_t ConvStrideH = conv_filter_strides[1];
        const index_t ConvStrideW = conv_filter_strides[2];

        const index_t ConvDilationD = conv_filter_dilations[0];
        const index_t ConvDilationH = conv_filter_dilations[1];
        const index_t ConvDilationW = conv_filter_dilations[2];

        const index_t InLeftPadD = input_left_pads[0];
        const index_t InLeftPadH = input_left_pads[1];
        const index_t InLeftPadW = input_left_pads[2];

        const index_t InRightPadD = input_right_pads[0];
        const index_t InRightPadH = input_right_pads[1];
        const index_t InRightPadW = input_right_pads[2];

        const index_t GemmKTotal = N * Do * Ho * Wo;
        const index_t GemmM      = K * NumBatchToMerge;
        const index_t GemmN      = C * Z * X * Y * NumBatchToMerge;

        const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
        const auto PadGemmN = NPerBlock - GemmN % NPerBlock;

        const index_t GemmKBatch = batch_k;
        const index_t GemmK0 =
            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
            K0PerBlock;
        const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;

        const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
        const auto in_grid_desc  = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
        const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);

        if constexpr(ConvBackwardWeightSpecialization ==
                     device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
            // A: output tensor
            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
                out_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
                out_gemmkpad_gemmm_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmM)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            // B: input tensor
            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
                in_gemmkpad_gemmn_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                              wei_grid_desc);
        }
        else
        {
            // A: output tensor
            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
                out_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
                out_gemmkpad_gemmm_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmM)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            // B: input tensor
            const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Di, InLeftPadD, InRightPadD),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(NumBatchToMerge),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));

            const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_dip_hip_wip_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                           make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                           make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                           make_pass_through_transform(NumBatchToMerge),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{}, Sequence<8>{}));

            const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumBatchToMerge, C)),
                           make_merge_transform(make_tuple(N, Do, Ho, Wo))),
                make_tuple(Sequence<1, 3, 5, 7, 8>{}, Sequence<0, 2, 4, 6>{}),
                make_tuple(Sequence<1>{}, Sequence<0>{}));

            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
                in_gemmktotal_gemmn_grid_desc,
                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
                in_gemmkpad_gemmn_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
                           make_pass_through_transform(GemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            // Pad
            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = transform_tensor_descriptor(
                out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
                           make_right_pad_transform(GemmM, PadGemmM),
                           make_pass_through_transform(GemmK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = transform_tensor_descriptor(
                in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
                           make_right_pad_transform(GemmN, PadGemmN),
                           make_pass_through_transform(GemmK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto wei_gemmm_gemmn_pad_grid_desc = transform_tensor_descriptor(
                wei_grid_desc,
                make_tuple(make_right_pad_transform(GemmM, PadGemmM),
                           make_right_pad_transform(GemmN, PadGemmN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
                              wei_gemmm_gemmn_pad_grid_desc);
        }
    } // function end
};

} // namespace tensor_operation
} // namespace ck
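A hedged sketch of how this new transform could be instantiated from user code follows; the block sizes and NumBatchToMerge below are made-up illustrative values, not a configuration shipped by the library:

    // Illustrative instantiation of TransformConvBwdWeightToGemmV2 (hypothetical parameters).
    #include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"

    using TransformV2 = ck::tensor_operation::TransformConvBwdWeightToGemmV2<
        2,  // NDimSpatial
        64, // MPerBlock
        64, // NPerBlock
        8,  // GemmK1Number
        32, // K0PerBlock
        2,  // NumBatchToMerge (must be a power of two, per the static_assert above)
        ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default>;

    // The A/B/C descriptors of the implied GEMM would then come from
    // TransformV2::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(...), given the
    // convolution lengths, strides, dilations, pads and the chosen batch_k split.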
include/ck_tile/core/config.hpp
@@ -3,6 +3,21 @@
 #pragma once

+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
+    defined(__gfx942__)
+#define __gfx9__
+#endif
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#define __gfx94__
+#endif
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
+    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
+#define __gfx103__
+#endif
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
+#define __gfx11__
+#endif
+
 #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
 #include "hip/hip_runtime.h"
 #include "hip/hip_fp16.h"
...
@@ -109,15 +124,13 @@
 // buffer atomic add: floating point
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
-#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__) // for GPU code
+#elif defined(__gfx9__) // for GPU code
 #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
 #else // for GPU code
 #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
 #endif

-#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__)) // for GPU code
+#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code
 #define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
 #else
 #define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
...
@@ -137,13 +150,12 @@
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
-#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__) // for GPU code
+#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \
+    defined(__gfx9__) // for GPU code
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
-#elif defined(__gfx1030__) // for GPU code
+#elif defined(__gfx103__) // for GPU code
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
-#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
+#elif defined(__gfx11__) // for GPU code
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
 #endif
...
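With the family macros in place, a guard can name a whole hardware family at once instead of enumerating every gfx id; a minimal sketch that reuses only the macros defined above:

    // Sketch: choose a code path per hardware family using the consolidated macros.
    #if defined(__gfx94__)
        // gfx940/941/942 path (e.g. native fp8 conversion builtins are available)
    #elif defined(__gfx9__)
        // remaining gfx9 targets (gfx908/gfx90a)
    #elif defined(__gfx103__) || defined(__gfx11__)
        // RDNA targets, grouped by generation
    #endif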
include/ck_tile/core/numeric/float8.hpp
@@ -55,7 +55,7 @@ struct alignas(1) float8_e4m3_t
 {
     static constexpr int exponent = 4;
     static constexpr int mantissa = 3;
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     static constexpr int bias = 1 << (exponent - 1); // NANOO
 #else
     static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
...
@@ -113,7 +113,7 @@ struct alignas(1) float8_e5m2_t
 {
     static constexpr int exponent = 5;
     static constexpr int mantissa = 2;
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     static constexpr int bias = 1 << (exponent - 1); // NANOO
 #else
     static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
...
@@ -470,7 +470,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x)
 {
     constexpr int seed = 42;
     uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     float max_fp8 = 240.0f;
     x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
     union
...
@@ -500,7 +500,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
 {
     constexpr int seed = 42;
     uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     union
     {
         float fval;
...
@@ -526,7 +526,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
 CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     float max_fp8 = 240.0f;
     x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
     union
...
@@ -554,7 +554,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
 }
 CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     union
     {
         float fval;
...
@@ -598,7 +598,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
 CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     float fval;
     uint32_t i32val = static_cast<uint32_t>(x);
     fval            = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
...
@@ -612,7 +612,7 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
 CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     float fval;
     uint32_t i32val = static_cast<uint32_t>(x);
     fval            = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
...
@@ -656,7 +656,7 @@ struct numeric_traits<fp8_t>
 {
     static constexpr int exp  = 4;
     static constexpr int mant = 3;
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     static constexpr int bias = 8;
 #else
     static constexpr int bias = 7;
...
@@ -668,7 +668,7 @@ struct numeric_traits<bf8_t>
 {
     static constexpr int exp  = 5;
     static constexpr int mant = 2;
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     static constexpr int bias = 16;
 #else
     static constexpr int bias = 15; // IEEE
...
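The two guarded bias conventions differ only by one: the gfx94 branch (labelled NANOO in the source) uses 1 << (exponent - 1), while the fallback (labelled IEEE) uses (1 << (exponent - 1)) - 1, which is exactly the 8/7 and 16/15 pair hard-coded in numeric_traits above. A tiny host-side check of that arithmetic, using only values quoted in the diff:

    // Verifies the bias arithmetic shown in the hunks above; nothing here touches GPU code.
    int main()
    {
        constexpr int e4m3_exponent = 4;
        constexpr int e5m2_exponent = 5;
        static_assert((1 << (e4m3_exponent - 1)) == 8);         // gfx94 / NANOO, fp8_t
        static_assert(((1 << (e4m3_exponent - 1)) - 1) == 7);   // IEEE-style fallback, fp8_t
        static_assert((1 << (e5m2_exponent - 1)) == 16);        // gfx94 / NANOO, bf8_t
        static_assert(((1 << (e5m2_exponent - 1)) - 1) == 15);  // IEEE-style fallback, bf8_t
        return 0;
    }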
include/ck_tile/core/tensor/tile_elementwise.hpp
@@ -112,7 +112,7 @@ namespace impl {
 template <typename OutDataType, typename InTensor>
 CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // This API is designed to use the _pk_ series of functions
     constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
@@ -36,8 +36,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
     CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__)
+#if defined(__gfx9__)
         c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
 #else
         ck_tile::ignore = c_vec;
...
@@ -49,8 +48,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
     // c_vec = a_vec * b_vec
     CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__)
+#if defined(__gfx9__)
         return bit_cast<CVecType>(
             __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
 #else
...
@@ -89,8 +87,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
     CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__)
+#if defined(__gfx9__)
         c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
 #else
         ck_tile::ignore = c_vec;
...
@@ -102,8 +99,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
     // c_vec = a_vec * b_vec
     CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__)
+#if defined(__gfx9__)
         return bit_cast<CVecType>(
             __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
 #else
...
@@ -143,7 +139,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
     CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx90a__) || defined(__gfx94__)
         c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
 #elif defined(__gfx908__)
         static_for<0, 2, 1>{}([&](auto k) {
...
@@ -167,7 +163,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
     // c_vec = a_vec * b_vec
     CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx90a__) || defined(__gfx94__)
         return bit_cast<CVecType>(
             __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
 #elif defined(__gfx908__)
...
@@ -220,7 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
     CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx90a__) || defined(__gfx94__)
         c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
 #elif defined(__gfx908__)
         static_for<0, 2, 1>{}([&](auto k) {
...
@@ -244,7 +240,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
     // c_vec = a_vec * b_vec
     CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx90a__) || defined(__gfx94__)
         return bit_cast<CVecType>(
             __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
 #elif defined(__gfx908__)
...
@@ -299,7 +295,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
     CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
             c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
                 bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
...
@@ -333,7 +329,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
     // c_vec = a_vec * b_vec
     CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
             return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
                 bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
@@ -35,14 +35,24 @@ template <ck::index_t NDimSpatial,
           typename ALayout,
           typename BLayout,
           typename ELayout,
-          ConvolutionBackwardWeightSpecialization ConvSpec>
+          ConvolutionBackwardWeightSpecialization ConvSpec,
+          BlockGemmPipelineScheduler Scheduler,
+          BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple<
    // clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
-        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 8, 1, 8>, 1>
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumBatch|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
+        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
+        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>,
+        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>,
+        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>,
+        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
+        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>,
+        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>,
+        DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>
        // clang-format on
        >;
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
@@ -352,7 +352,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
        {
            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
-            add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
+            add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances(
                op_ptrs);
        }
#endif
...
@@ -421,7 +423,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
        {
            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
-            add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+            add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances(
+                op_ptrs);
+            add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances(
                op_ptrs);
        }
#endif
...
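Callers pick up the added pipev2/pipev5 instances through the same factory path as before. A hedged sketch of the usual CK lookup (this is standard instance-factory usage, not code from this commit, and the type spellings are written out from the declarations in the next file):

    // Enumerate all registered grouped conv2d bwd-weight fp16 instances (NHWGC/GKYXC/NHWGK).
    using namespace ck::tensor_operation::device;

    using DeviceOp = DeviceGroupedConvBwdWeight<2,
                                                ck::tensor_layout::convolution::NHWGC,
                                                ck::tensor_layout::convolution::GKYXC,
                                                ck::tensor_layout::convolution::NHWGK,
                                                ck::half_t, ck::half_t, ck::half_t,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough>;

    const auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    // op_ptrs now also contains the two_stage_xdl pipev2 and pipev5 instances added here.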
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
@@ -114,7 +114,19 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
                                                    PassThrough,
                                                    PassThrough>>>& instances);

-void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
+void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
+                                                           NHWGC,
+                                                           GKYXC,
+                                                           NHWGK,
+                                                           F16,
+                                                           F16,
+                                                           F16,
+                                                           PassThrough,
+                                                           PassThrough,
+                                                           PassThrough>>>& instances);
+
+void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                           NHWGC,
                                                           GKYXC,
...
@@ -205,7 +217,19 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
                                                    PassThrough,
                                                    PassThrough>>>& instances);

-void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
+                                                           NDHWGC,
+                                                           GKZYXC,
+                                                           NDHWGK,
+                                                           F16,
+                                                           F16,
+                                                           F16,
+                                                           PassThrough,
+                                                           PassThrough,
+                                                           PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           NDHWGC,
                                                           GKZYXC,
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
e4112de7
...
...
@@ -36,6 +36,13 @@ function(add_instance_library INSTANCE_NAME)
endif
()
endforeach
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
INST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
INST_TARGETS
${
GPU_TARGETS
}
)
endif
()
# Do not build DL instances if DL_KERNELS macro is not set
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
...
...
@@ -45,21 +52,40 @@ function(add_instance_library INSTANCE_NAME)
    endforeach()
    # Do not build XDL instances if gfx9 targets are not on the target list
    foreach(source IN LISTS ARGN)
-       if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
+       if(NOT INST_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
            message("removing xdl instance ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    # Do not build WMMA instances if gfx11 targets are not on the target list
    foreach(source IN LISTS ARGN)
-       if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+       if(NOT INST_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
            message("removing wmma instance ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    #only continue if there are some source files left on the list
    if(ARGN)
-       add_library(${INSTANCE_NAME} OBJECT ${ARGN})
+       set(INST_OBJ)
+       foreach(source IN LISTS ARGN)
+           if(INSTANCES_ONLY)
+               set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
+           else()
+               set(INST_TARGETS ${GPU_TARGETS})
+           endif()
+           if(source MATCHES "_xdl")
+               list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
+           elseif(ARGN MATCHES "_wmma")
+               list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+           endif()
+           set(offload_targets)
+           foreach(target IN LISTS INST_TARGETS)
+               string(APPEND offload_targets "--offload-arch=${target}")
+           endforeach()
+           set_source_files_properties(${source} PROPERTIES COMPILE_FLAGS ${offload_targets})
+           list(APPEND INST_OBJ ${source})
+       endforeach()
+       add_library(${INSTANCE_NAME} OBJECT ${INST_OBJ})
        target_compile_features(${INSTANCE_NAME} PUBLIC)
        set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
        clang_tidy_check(${INSTANCE_NAME})
...
...
@@ -131,6 +157,14 @@ FOREACH(subdir_path ${dir_list})
    if(NOT DEFINED DTYPES)
        set(add_inst 1)
    endif()
+   if(INSTANCES_ONLY)
+       set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
+   else()
+       set(INST_TARGETS ${GPU_TARGETS})
+   endif()
    if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
        message("quantization instances will not be built!")
        set(add_inst 0)
...
...
@@ -139,23 +173,23 @@ FOREACH(subdir_path ${dir_list})
        message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
        set(add_inst 0)
    endif()
-   if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx9"))
+   if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9"))
        message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
        set(add_inst 0)
    endif()
-   if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11"))
+   if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11"))
        message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
        set(add_inst 0)
    endif()
-   if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT GPU_TARGETS MATCHES "gfx9"))
+   if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT INST_TARGETS MATCHES "gfx9"))
        message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.")
        set(add_inst 0)
    endif()
-   if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9"))
+   if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9"))
        message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
        set(add_inst 0)
    endif()
-   if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
+   if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
        message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
        set(add_inst 0)
    endif()
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
View file @
e4112de7
...
...
@@ -6,7 +6,9 @@ set(GROUPED_CONV2D_BWD_WEIGHT
    xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
    xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
    xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
-   xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp)
+   xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
+   xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp)
if(DL_KERNELS)
    list(APPEND GROUPED_CONV2D_BWD_WEIGHT
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp → library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
View file @ e4112de7
...
...
@@ -10,7 +10,7 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
-void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
+void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                           NHWGC,
                                                           GKYXC,
...
...
@@ -30,16 +30,9 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_in
            NHWGC,
            GKYXC,
            NHWGK,
-           ConvBwdWeightDefault>{});
-   // 2. Filter1x1Stride1Pad0
-   add_device_operation_instances(
-       instances,
-       device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<2,
-                                                                            NHWGC,
-                                                                            GKYXC,
-                                                                            NHWGK,
-                                                                            ConvBwdWeightFilter1x1Stride1Pad0>{});
+           ConvBwdWeightDefault,
+           BlockGemmPipelineScheduler::Intrawave,
+           BlockGemmPipelineVersion::v2>{});
}
} // namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
0 → 100644
View file @
e4112de7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                           NHWGC,
                                                           GKYXC,
                                                           NHWGK,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<
            2,
            NHWGC,
            GKYXC,
            NHWGK,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v5>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
View file @
e4112de7
# XDL_DL_WMMA_KERNELS
set(GROUPED_CONV3D_BWD_WEIGHT
    xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
    xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
    xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
    xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
    xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
    xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
-   xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp)
+   xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
+   xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp)
if(DL_KERNELS)
    list(APPEND GROUPED_CONV3D_BWD_WEIGHT
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp → library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
View file @ e4112de7
...
...
@@ -10,7 +10,7 @@ namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
-void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           NDHWGC,
                                                           GKZYXC,
...
...
@@ -30,16 +30,9 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
            NDHWGC,
            GKZYXC,
            NDHWGK,
-           ConvBwdWeightDefault>{});
-   // 2. Filter1x1Stride1Pad0
-   add_device_operation_instances(
-       instances,
-       device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<3,
-                                                                            NDHWGC,
-                                                                            GKZYXC,
-                                                                            NDHWGK,
-                                                                            ConvBwdWeightFilter1x1Stride1Pad0>{});
+           ConvBwdWeightDefault,
+           BlockGemmPipelineScheduler::Intrawave,
+           BlockGemmPipelineVersion::v2>{});
}
} // namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
0 → 100644
View file @
e4112de7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           NDHWGC,
                                                           GKZYXC,
                                                           NDHWGK,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<
            3,
            NDHWGC,
            GKZYXC,
            NDHWGK,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v5>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
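A quick way to confirm that both pipeline variants end up in the factory is to enumerate the registered instances and print their type strings. A self-contained sketch (again assuming the standard factory API and a linked instance library; not part of this commit):

#include <iostream>

#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"

int main()
{
    using F16         = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using DeviceOp    = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
        3,
        ck::tensor_layout::convolution::NDHWGC,
        ck::tensor_layout::convolution::GKZYXC,
        ck::tensor_layout::convolution::NDHWGK,
        F16, F16, F16, PassThrough, PassThrough, PassThrough>;

    // Prints every registered 3d f16 bwd-weight instance, which should now
    // include the two-stage Intrawave/v2 and Intrawave/v5 variants.
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n';

    return 0;
}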
test/CMakeLists.txt
View file @
e4112de7
...
...
@@ -40,6 +40,13 @@ function(add_test_executable TEST_NAME)
            endif()
        endforeach()
    endif()
+   if(INSTANCES_ONLY)
+       set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
+   else()
+       set(TEST_TARGETS ${GPU_TARGETS})
+   endif()
    foreach(source IN LISTS ARGN)
        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
            message("removing dl test ${source}")
...
...
@@ -47,20 +54,27 @@ function(add_test_executable TEST_NAME)
        endif()
    endforeach()
    foreach(source IN LISTS ARGN)
-       if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
+       if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
            message("removing xdl test ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    foreach(source IN LISTS ARGN)
-       if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
+       if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
            message("removing wmma test ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    #only continue if there are some source files left on the list
    if(ARGN)
        if(ARGN MATCHES "_xdl")
            list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
        elseif(ARGN MATCHES "_wmma")
            list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
        endif()
        set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
        add_executable(${TEST_NAME} ${ARGN})
        set_property(TARGET ${TEST_NAME} PROPERTY HIP_ARCHITECTURES ${TEST_TARGETS})
        target_link_libraries(${TEST_NAME} PRIVATE getopt::getopt)
        add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
        add_dependencies(tests ${TEST_NAME})
...
...
@@ -105,6 +119,13 @@ function(add_gtest_executable TEST_NAME)
            endif()
        endforeach()
    endif()
+   if(INSTANCES_ONLY)
+       set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
+   else()
+       set(TEST_TARGETS ${GPU_TARGETS})
+   endif()
    foreach(source IN LISTS ARGN)
        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
            message("removing dl test ${source}")
...
...
@@ -112,20 +133,27 @@ function(add_gtest_executable TEST_NAME)
        endif()
    endforeach()
    foreach(source IN LISTS ARGN)
-       if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
+       if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
            message("removing xdl test ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    foreach(source IN LISTS ARGN)
-       if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
+       if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
            message("removing wmma test ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    #only continue if there are some source files left on the list
    if(ARGN)
        if(ARGN MATCHES "_xdl")
            list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
        elseif(ARGN MATCHES "_wmma")
            list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
        endif()
        set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
        add_executable(${TEST_NAME} ${ARGN})
        set_property(TARGET ${TEST_NAME} PROPERTY HIP_ARCHITECTURES ${TEST_TARGETS})
        add_dependencies(tests ${TEST_NAME})
        add_dependencies(check ${TEST_NAME})
...
...
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
View file @
e4112de7
...
...
@@ -32,19 +32,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
    std::vector<ck::utils::conv::ConvParam> conv_params;
    std::vector<ck::index_t> split_ks{1, 2};
-   bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k)
+   bool skip_case(const ck::index_t split_k)
    {
-       // Odd K or C values are supported only by DL and WMMA
-       // kernels (only applies to fp16)
-       // DL and WMMA kernels currently support only `split_k=1`
-       if constexpr(std::is_same_v<InDataType, ck::half_t>)
-       {
-           if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0))
-           {
-               return true;
-           }
-       }
        // 1d NWGC is only supported by DL kernel
        // DL kernel is only supported for split_k=1
        if constexpr(std::is_same_v<InLayout, NWGC> && std::is_same_v<OutLayout, NWGK>
...
...
@@ -100,7 +89,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
    {
        for(auto& param : conv_params)
        {
-           if(!skip_case(param, split_k))
+           if(!skip_case(split_k))
            {
                pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial{},
                                                                                  InLayout,
...
...
@@ -189,6 +178,8 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
    this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
+   this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
+   this->conv_params.push_back(
+       {2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
    this->Run();
}
...
...
@@ -207,5 +198,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
        {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
+   this->conv_params.push_back(
+       {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
+   this->conv_params.push_back(
+       {3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
    this->Run();
}