Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7830272f
Commit
7830272f
authored
Oct 17, 2023
by
Artur Wojcik
Browse files
Merge branch 'develop' into uif2-initial
parents
2b8a9941
16d7c4d2
Changes
71
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
268 additions
and
14 deletions
+268
-14
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
...nv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
+0
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
...v3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+0
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp
...ht_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp
+0
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
...nv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
+0
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
...nv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
+0
-0
profiler/README.md
profiler/README.md
+4
-2
profiler/src/profile_grouped_conv_bwd_weight.cpp
profiler/src/profile_grouped_conv_bwd_weight.cpp
+17
-5
test/grouped_convnd_bwd_weight/CMakeLists.txt
test/grouped_convnd_bwd_weight/CMakeLists.txt
+13
-4
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
...uped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+43
-3
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp
..._weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp
+191
-0
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp
...d_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp
+0
-0
No files found.
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/
xdl/
device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
View file @
7830272f
File moved
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/
xdl/
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
View file @
7830272f
File moved
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/
xdl/
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp
View file @
7830272f
File moved
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/
xdl/
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
View file @
7830272f
File moved
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/
xdl/
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
View file @
7830272f
File moved
profiler/README.md
View file @
7830272f
...
@@ -147,7 +147,9 @@ GB/s: 127.947
...
@@ -147,7 +147,9 @@ GB/s: 127.947
# arg1: tensor operation (grouped_conv_bwd_weight: Grouped Convolution Backward Weight)
# arg1: tensor operation (grouped_conv_bwd_weight: Grouped Convolution Backward Weight)
# arg2: data type (0: Input fp32, Weight fp32, Output fp32
# arg2: data type (0: Input fp32, Weight fp32, Output fp32
# 1: Input fp16, Weight fp16, Output fp16
# 1: Input fp16, Weight fp16, Output fp16
# 2: Input bf16, Weight fp32, Output bf16)
# 2: Input bf16, Weight fp32, Output bf16
# 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8
# 4: Input int8, Weight int8, Output int8)
# arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, N, K, Ho, Wo]
# arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, N, K, Ho, Wo]
# 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]
# 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]
# 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]
# 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]
...
@@ -167,7 +169,7 @@ GB/s: 127.947
...
@@ -167,7 +169,7 @@ GB/s: 127.947
# SplitK
# SplitK
################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK
################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK
./bin/ckProfiler grouped_conv_bwd_weight
1
0
1
1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1
./bin/ckProfiler grouped_conv_bwd_weight 1
1
0
1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1
```
```
...
...
profiler/src/profile_grouped_conv_bwd_weight.cpp
View file @
7830272f
...
@@ -20,10 +20,11 @@ enum struct ConvLayout
...
@@ -20,10 +20,11 @@ enum struct ConvLayout
enum
struct
ConvDataType
enum
struct
ConvDataType
{
{
F32_F32_F32
,
// 0
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
F16_F16_F16
,
// 1
BF16_F32_BF16
,
// 2
BF16_F32_BF16
,
// 2
F16_F16_F16_BF8_F8
// 3
F16_F16_F16_BF8_F8
,
// 3
I8_I8_I8
// 4
};
};
#define OP_NAME "grouped_conv_bwd_weight"
#define OP_NAME "grouped_conv_bwd_weight"
...
@@ -35,7 +36,8 @@ static void print_helper_msg()
...
@@ -35,7 +36,8 @@ static void print_helper_msg()
<<
"arg2: data type (0: Input fp32, Weight fp32, Output fp32
\n
"
<<
"arg2: data type (0: Input fp32, Weight fp32, Output fp32
\n
"
<<
" 1: Input fp16, Weight fp16, Output fp16
\n
"
<<
" 1: Input fp16, Weight fp16, Output fp16
\n
"
<<
" 2: Input bf16, Weight fp32, Output bf16
\n
"
<<
" 2: Input bf16, Weight fp32, Output bf16
\n
"
<<
" 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8)
\n
"
<<
" 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8
\n
"
<<
" 4: Input int8, Weight int8, Output int8)
\n
"
<<
"arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
<<
"arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
"N, K, Ho, Wo]
\n
"
"N, K, Ho, Wo]
\n
"
<<
" 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
<<
" 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
...
@@ -196,6 +198,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
...
@@ -196,6 +198,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
// fp32 atomic add is used for weight tensor in bf16 kernel
// fp32 atomic add is used for weight tensor in bf16 kernel
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
BF16
{},
F32
{},
BF16
{},
BF16
{},
BF16
{});
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
BF16
{},
F32
{},
BF16
{},
BF16
{},
BF16
{});
}
}
else
if
(
data_type
==
ConvDataType
::
I8_I8_I8
)
{
return
profile
(
I3
,
GNDHWC
{},
GKZYXC
{},
GNDHWK
{},
int8_t
{},
int8_t
{},
int8_t
{},
int8_t
{},
int8_t
{});
}
}
}
else
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
NHWGC_GKYXC_NHWGK
)
else
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
NHWGC_GKYXC_NHWGK
)
{
{
...
@@ -216,6 +223,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
...
@@ -216,6 +223,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F16
{},
F16
{},
F16
{},
BF8
{},
F8
{});
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F16
{},
F16
{},
F16
{},
BF8
{},
F8
{});
}
}
else
if
(
data_type
==
ConvDataType
::
I8_I8_I8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
int8_t
{},
int8_t
{},
int8_t
{},
int8_t
{},
int8_t
{});
}
}
}
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
test/grouped_convnd_bwd_weight/CMakeLists.txt
View file @
7830272f
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list
_xdl
AND target EQUAL 0
)
add_gtest_executable
(
test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp
)
add_gtest_executable
(
test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp
)
target_link_libraries
(
test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance
)
target_link_libraries
(
test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance
)
add_gtest_executable
(
test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface.cpp
)
add_gtest_executable
(
test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp
)
target_link_libraries
(
test_grouped_convnd_bwd_weight_interface PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance
)
target_link_libraries
(
test_grouped_convnd_bwd_weight_interface PRIVATE utility
)
set
(
target 1
)
endif
()
if
(
gpu IN_LIST gpu_list_wmma AND target EQUAL 0
)
add_gtest_executable
(
test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp
)
target_link_libraries
(
test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance
)
add_gtest_executable
(
test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp
)
target_link_libraries
(
test_grouped_convnd_bwd_weight_interface PRIVATE utility
)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
\ No newline at end of file
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
View file @
7830272f
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
...
@@ -33,8 +34,9 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
...
@@ -33,8 +34,9 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
bool
skip_case
(
const
ck
::
utils
::
conv
::
ConvParam
&
params
,
const
ck
::
index_t
split_k
)
bool
skip_case
(
const
ck
::
utils
::
conv
::
ConvParam
&
params
,
const
ck
::
index_t
split_k
)
{
{
// Odd K or C values are supported only by DL kernel (only applies to fp16)
// Odd K or C values are supported only by DL and WMMA
// DL kernel currently supports only `split_k=1`
// kernels (only applies to fp16)
// DL and WMMA kernels currently support only `split_k=1`
if
constexpr
(
std
::
is_same_v
<
InDataType
,
ck
::
half_t
>
)
if
constexpr
(
std
::
is_same_v
<
InDataType
,
ck
::
half_t
>
)
{
{
if
(
split_k
!=
1
&&
(
params
.
K_
%
2
!=
0
||
params
.
C_
%
2
!=
0
))
if
(
split_k
!=
1
&&
(
params
.
K_
%
2
!=
0
||
params
.
C_
%
2
!=
0
))
...
@@ -53,6 +55,42 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
...
@@ -53,6 +55,42 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
}
}
}
}
const
bool
is_navi3x
=
ck
::
get_device_name
()
==
"gfx1100"
||
ck
::
get_device_name
()
==
"gfx1101"
||
ck
::
get_device_name
()
==
"gfx1102"
;
if
(
is_navi3x
)
{
// on navi3x only support for 3d is implemented
if
constexpr
(
NDimSpatial
{}
!=
3
)
{
return
true
;
}
// on navi3x only support for i8 and fp16 is implemented
if
constexpr
(
!
((
std
::
is_same_v
<
InDataType
,
int8_t
>
&&
std
::
is_same_v
<
WeiDataType
,
int8_t
>
&&
std
::
is_same_v
<
OutDataType
,
int8_t
>
)
||
(
std
::
is_same_v
<
InDataType
,
ck
::
half_t
>
&&
std
::
is_same_v
<
WeiDataType
,
ck
::
half_t
>
&&
std
::
is_same_v
<
OutDataType
,
ck
::
half_t
>
)))
{
return
true
;
}
// WMMA kernel is only supported for split_k=1
if
(
split_k
!=
1
)
{
return
true
;
}
}
else
{
// support for i8 is only implemented on navi3x
if
constexpr
(
std
::
is_same_v
<
InDataType
,
int8_t
>
&&
std
::
is_same_v
<
WeiDataType
,
int8_t
>
&&
std
::
is_same_v
<
OutDataType
,
int8_t
>
)
{
return
true
;
}
}
return
false
;
return
false
;
}
}
...
@@ -120,9 +158,11 @@ using KernelTypes3d = ::testing::Types<
...
@@ -120,9 +158,11 @@ using KernelTypes3d = ::testing::Types<
std
::
tuple
<
float
,
float
,
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
float
,
float
,
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
int8_t
,
int8_t
,
int8_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
float
,
float
,
float
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
float
,
float
,
float
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>>
;
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
int8_t
,
int8_t
,
int8_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>>
;
TYPED_TEST_SUITE
(
TestGroupedConvndBwdWeight1d
,
KernelTypes1d
);
TYPED_TEST_SUITE
(
TestGroupedConvndBwdWeight1d
,
KernelTypes1d
);
TYPED_TEST_SUITE
(
TestGroupedConvndBwdWeight2d
,
KernelTypes2d
);
TYPED_TEST_SUITE
(
TestGroupedConvndBwdWeight2d
,
KernelTypes2d
);
...
...
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp
0 → 100644
View file @
7830272f
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include <gtest/gtest.h>
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
ConvolutionBackwardWeightSpecialization
=
ck
::
tensor_operation
::
device
::
ConvolutionBackwardWeightSpecialization
;
static
constexpr
auto
ConvBwdWeightDefault
=
ConvolutionBackwardWeightSpecialization
::
Default
;
static
constexpr
auto
Filter1x1Stride1Pad0
=
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
;
/// Fixture that instantiates a single DeviceGroupedConvBwdWeight_Wmma_CShuffle
/// configuration and queries whether it accepts the problem stored in `conv_param`.
///
/// @tparam Tuple    std::tuple<OutLayout, WeiLayout, InLayout, ck::Number<NDimSpatial>>
/// @tparam ConvSpec backward-weight specialization forwarded to the device operation
template <typename Tuple, ConvolutionBackwardWeightSpecialization ConvSpec>
class TestGroupedConvndBwdWeight : public ::testing::Test
{
    protected:
    // Note the tuple order: output layout first, input layout last.
    using OutLayout = std::tuple_element_t<0, Tuple>;
    using WeiLayout = std::tuple_element_t<1, Tuple>;
    using InLayout  = std::tuple_element_t<2, Tuple>;

    static constexpr ck::index_t NDimSpatial = std::tuple_element_t<3, Tuple>{};

    // clang-format off
    using GroupedConvBwdWeightDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle
        //| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
        //| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
        //| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
        //| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        <NDimSpatial, InLayout, WeiLayout, OutLayout,
         F16, F16, F16, F32,
         PassThrough, PassThrough, PassThrough,
         ConvSpec,
         128, 128, 128, 8, 8, 16, 16, 4, 4,
         S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1,
         S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1,
         1, 1, S<1, 32, 1, 4>, 8>;
    // clang-format on

    // Problem description; each test assigns it before calling Run().
    ck::utils::conv::ConvParam conv_param;

    /// Builds a device-op argument from `conv_param` (with null data pointers) and
    /// returns the result of IsSupportedArgument() for it.
    ///
    /// @tparam SplitK split-K value handed to MakeArgument
    /// @return true when the device instance reports the argument as supported
    template <ck::index_t SplitK>
    bool Run()
    {
        // G, N, C/K plus the spatial dimensions.
        using TensorDims = std::array<ck::index_t, NDimSpatial + 3>;
        // One entry per spatial dimension.
        using SpatialDims = std::array<ck::index_t, NDimSpatial>;

        namespace conv_utils = ck::utils::conv;

        const auto input_desc =
            conv_utils::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
        const auto weight_desc =
            conv_utils::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
        const auto output_desc =
            conv_utils::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);

        TensorDims input_lengths{};
        TensorDims filter_lengths{};
        TensorDims output_lengths{};

        TensorDims input_strides{};
        TensorDims weights_strides{};
        TensorDims output_strides{};

        SpatialDims conv_filter_strides{};
        SpatialDims conv_filter_dilations{};
        SpatialDims input_left_pads{};
        SpatialDims input_right_pads{};

        // Copies every element of `src` to the front of the fixed-size array `dst`.
        const auto fill_from = [](const auto& src, auto& dst) {
            std::copy(begin(src), end(src), begin(dst));
        };

        fill_from(input_desc.GetLengths(), input_lengths);
        fill_from(input_desc.GetStrides(), input_strides);
        fill_from(weight_desc.GetLengths(), filter_lengths);
        fill_from(weight_desc.GetStrides(), weights_strides);
        fill_from(output_desc.GetLengths(), output_lengths);
        fill_from(output_desc.GetStrides(), output_strides);
        fill_from(conv_param.conv_filter_strides_, conv_filter_strides);
        fill_from(conv_param.conv_filter_dilations_, conv_filter_dilations);
        fill_from(conv_param.input_left_pads_, input_left_pads);
        fill_from(conv_param.input_right_pads_, input_right_pads);

        GroupedConvBwdWeightDeviceInstance device_op{};

        // Only the support query is exercised, so no device buffers are supplied.
        auto argument = device_op.MakeArgument(nullptr,
                                               nullptr,
                                               nullptr,
                                               input_lengths,
                                               input_strides,
                                               filter_lengths,
                                               weights_strides,
                                               output_lengths,
                                               output_strides,
                                               conv_filter_strides,
                                               conv_filter_dilations,
                                               input_left_pads,
                                               input_right_pads,
                                               PassThrough{},
                                               PassThrough{},
                                               PassThrough{},
                                               SplitK);

        return device_op.IsSupportedArgument(argument);
    }
};
namespace conv_layout = ck::tensor_layout::convolution;

// Layout combinations under test, each given as <OutLayout, WeiLayout, InLayout, NDimSpatial>.
using KernelTypes3d = ::testing::Types<
    std::tuple<conv_layout::GNDHWK, conv_layout::GKZYXC, conv_layout::GNDHWC, ck::Number<3>>,
    std::tuple<conv_layout::NDHWGK, conv_layout::GKZYXC, conv_layout::NDHWGC, ck::Number<3>>>;

// 3D fixture bound to the Filter1x1Stride1Pad0 specialization.
template <typename Tuple>
class TestGroupedConvndBwdWeightFilter1x13d
    : public TestGroupedConvndBwdWeight<Tuple, Filter1x1Stride1Pad0>
{
};

// 3D fixture bound to the Default specialization.
template <typename Tuple>
class TestGroupedConvndBwdWeightDefault3d
    : public TestGroupedConvndBwdWeight<Tuple, ConvBwdWeightDefault>
{
};

TYPED_TEST_SUITE(TestGroupedConvndBwdWeightFilter1x13d, KernelTypes3d);
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightDefault3d, KernelTypes3d);
// Checks that the Filter1x1Stride1Pad0 instance rejects problems that violate the
// specialization and accepts one that satisfies it.
// NOTE(review): the ConvParam initializer order is presumably
// {NDimSpatial, G, N, K, C, filter, input, strides, dilations, left pads, right pads}
// -- confirm against ck::utils::conv::ConvParam.
TYPED_TEST(TestGroupedConvndBwdWeightFilter1x13d, SpecializationCheck)
{
    {
        // Filter is 3x3x3 rather than 1x1x1.
        this->conv_param = {
            3, 2, 4, 192, 192, {3, 3, 3}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
        const bool is_supported = this->template Run<1>();
        EXPECT_FALSE(is_supported);
    }

    {
        // Strides are 2x2x2 rather than 1x1x1.
        this->conv_param = {
            3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
        const bool is_supported = this->template Run<1>();
        EXPECT_FALSE(is_supported);
    }

    {
        // Non-zero padding on both sides.
        this->conv_param = {
            3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
        const bool is_supported = this->template Run<1>();
        EXPECT_FALSE(is_supported);
    }

    {
        // 1x1x1 filter, unit strides, no padding: the supported configuration.
        this->conv_param = {
            3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
        const bool is_supported = this->template Run<1>();
        EXPECT_TRUE(is_supported);
    }
}
// Checks that the Default instance rejects problems whose 4th/5th ConvParam values
// (129 and 257 below) are odd; the original comments attribute the rejections to the
// vector-load requirements of the operands.
TYPED_TEST(TestGroupedConvndBwdWeightDefault3d, VectorLoadCheck)
{
    {
        // vector load for A (fourth value is 129)
        this->conv_param = {
            3, 2, 128, 129, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
        const bool is_supported = this->template Run<1>();
        EXPECT_FALSE(is_supported);
    }

    {
        // vector load for B, E, Ds (fifth value is 257)
        // NOTE(review): the "E, Ds" wording looks inherited from another test -- confirm
        // which operands it refers to for this op.
        this->conv_param = {
            3, 2, 128, 128, 257, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
        const bool is_supported = this->template Run<1>();
        EXPECT_FALSE(is_supported);
    }
}
// Checks that the same problem is accepted with SplitK=1 and rejected with SplitK=2.
TYPED_TEST(TestGroupedConvndBwdWeightDefault3d, SplitKCheck)
{
    {
        // SplitK=1
        this->conv_param = {
            3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
        const bool is_supported = this->template Run<1>();
        EXPECT_TRUE(is_supported);
    }

    {
        // SplitK=2 (identical problem; only the split-K template argument differs)
        this->conv_param = {
            3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
        const bool is_supported = this->template Run<2>();
        EXPECT_FALSE(is_supported);
    }
}
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
→
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface
_xdl
.cpp
View file @
7830272f
File moved
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment