Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
139b950f
Commit
139b950f
authored
Jul 07, 2023
by
Bartlomiej Kocot
Browse files
Fix comments
parent
e6f4653a
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
256 additions
and
217 deletions
+256
-217
client_example/11_grouped_conv_bwd_weight/common.hpp
client_example/11_grouped_conv_bwd_weight/common.hpp
+6
-4
client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp
...rouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp
+22
-6
client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp
...rouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp
+16
-9
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp
...rouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp
+16
-9
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp
...rouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp
+23
-15
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
...uped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
+1
-1
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
...uped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
...r_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
+13
-13
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
...impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
+84
-84
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+53
-53
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_instance.hpp
..._weight/device_grouped_conv2d_bwd_weight_xdl_instance.hpp
+0
-2
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
...uped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+21
-20
No files found.
client_example/11_grouped_conv_bwd_weight/common.hpp
View file @
139b950f
...
...
@@ -101,10 +101,10 @@ template <ck::index_t NumDimSpatial,
typename
WeiLayout
,
typename
OutLayout
>
bool
run_grouped_conv_bwd_weight
(
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_spatial_lengths
,
...
...
@@ -228,6 +228,8 @@ bool run_grouped_conv_bwd_weight(
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp
View file @
139b950f
...
...
@@ -22,10 +22,15 @@ static constexpr ck::index_t C = 192;
static
constexpr
ck
::
index_t
X
=
3
;
static
constexpr
ck
::
index_t
Wi
=
28
;
static
constexpr
ck
::
index_t
Wo
=
28
;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
G
*
N
*
Wi
*
C
,
N
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
G
*
N
*
Wo
*
K
,
N
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_spatial_lengths
{
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_spatial_lengths
{
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_spatial_lengths
{
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
N
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_right_pads
{
1
};
int
main
()
{
...
...
@@ -35,8 +40,19 @@ int main()
OutDataType
,
InLayout
,
WeiLayout
,
OutLayout
>
(
G
,
N
,
K
,
C
,
{
Wi
},
{
X
},
{
Wo
},
input_strides
,
output_strides
,
{}
{
1
},
{
1
},
{
1
},
{
1
})
OutLayout
>
(
G
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
)
?
EXIT_SUCCESS
:
EXIT_FAILURE
;
}
client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp
View file @
139b950f
...
...
@@ -25,10 +25,17 @@ static constexpr ck::index_t Hi = 28;
static
constexpr
ck
::
index_t
Wi
=
28
;
static
constexpr
ck
::
index_t
Ho
=
28
;
static
constexpr
ck
::
index_t
Wo
=
28
;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_spatial_lengths
{
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_spatial_lengths
{
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_spatial_lengths
{
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
G
*
N
*
Hi
*
Wi
*
C
,
N
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
N
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
G
*
N
*
Ho
*
Wo
*
K
,
N
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
N
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_right_pads
{
1
,
1
};
int
main
()
{
...
...
@@ -42,15 +49,15 @@ int main()
N
,
K
,
C
,
{
Hi
,
Wi
}
,
{
Y
,
X
}
,
{
Ho
,
Wo
}
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_strides
,
output_strides
,
{
1
,
1
}
,
{
1
,
1
}
,
{
1
,
1
}
,
{
1
,
1
}
)
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
)
?
EXIT_SUCCESS
:
EXIT_FAILURE
;
}
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp
View file @
139b950f
...
...
@@ -28,10 +28,17 @@ static constexpr ck::index_t Wi = 3;
static
constexpr
ck
::
index_t
Do
=
28
;
static
constexpr
ck
::
index_t
Ho
=
28
;
static
constexpr
ck
::
index_t
Wo
=
3
;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_spatial_lengths
{
Di
,
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_spatial_lengths
{
Z
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_spatial_lengths
{
Do
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
G
*
N
*
Di
*
Hi
*
Wi
*
C
,
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
G
*
N
*
Do
*
Ho
*
Wo
*
K
,
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_right_pads
{
1
,
1
,
1
};
int
main
()
{
...
...
@@ -45,15 +52,15 @@ int main()
N
,
K
,
C
,
{
Di
,
Hi
,
Wi
}
,
{
Z
,
Y
,
X
}
,
{
Do
,
Ho
,
Wo
}
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_strides
,
output_strides
,
{
1
,
1
,
1
}
,
{
1
,
1
,
1
}
,
{
1
,
1
,
1
}
,
{
1
,
1
,
1
}
)
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
)
?
EXIT_SUCCESS
:
EXIT_FAILURE
;
}
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp
View file @
139b950f
...
...
@@ -28,10 +28,17 @@ static constexpr ck::index_t Wi = 3;
static
constexpr
ck
::
index_t
Do
=
28
;
static
constexpr
ck
::
index_t
Ho
=
28
;
static
constexpr
ck
::
index_t
Wo
=
3
;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_spatial_lengths
{
Di
,
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_spatial_lengths
{
Z
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_spatial_lengths
{
Do
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
G
*
N
*
Di
*
Hi
*
Wi
*
C
,
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
G
*
N
*
Do
*
Ho
*
Wo
*
K
,
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
,
1
,
1
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_right_pads
{
1
,
1
,
1
};
int
main
()
{
...
...
@@ -41,19 +48,20 @@ int main()
OutDataType
,
InLayout
,
WeiLayout
,
OutLayout
>
(
G
,
N
,
K
,
C
,
{
Di
,
Hi
,
Wi
},
{
Z
,
Y
,
X
},
{
Do
,
Ho
,
Wo
},
input_strides
,
output_strides
,
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
})
OutLayout
>
(
G
,
N
,
K
,
C
,
{
Di
,
Hi
,
Wi
},
{
Z
,
Y
,
X
},
{
Do
,
Ho
,
Wo
},
{
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
,
1
},
{
N
*
Do
*
Ho
*
Wo
*
K
,
Do
*
Ho
*
Wo
*
K
,
Ho
*
Wo
*
K
,
Wo
*
K
,
K
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
})
?
EXIT_SUCCESS
:
EXIT_FAILURE
;
}
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
View file @
139b950f
...
...
@@ -18,7 +18,7 @@ using OutElementOp = PassThrough;
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
// NDimSpatial
NDimSpatial
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GNWC
,
ck
::
tensor_layout
::
convolution
::
GNHWC
,
...
...
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
View file @
139b950f
...
...
@@ -17,7 +17,7 @@ using OutElementOp = PassThrough;
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
// NDimSpatial
NDimSpatial
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
GNWC
,
ck
::
tensor_layout
::
convolution
::
GNHWC
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
View file @
139b950f
...
...
@@ -27,19 +27,19 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
MakeArgumentPointer
(
const
void
*
p_in
,
void
*
p_wei
,
const
void
*
p_out
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
View file @
139b950f
...
...
@@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
filter_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_right_pads
,
const
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
...
...
@@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
}
// function end
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
filter_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_right_pads
,
const
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
...
...
@@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
filter_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_right_pads
,
const
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
...
...
@@ -784,19 +784,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
/*input_strides*/
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
/*output_strides*/
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
/*input_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
/*output_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
@@ -899,18 +899,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
InElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
index_t
Conv_G_
;
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads_
;
const
index_t
Conv_G_
;
const
index_t
Conv_N_
;
const
index_t
Conv_K_
;
const
index_t
Conv_C_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
filter_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
output_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
conv_filter_strides_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
conv_filter_dilations_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_left_pads_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_right_pads_
;
index_t
k_batch_
;
};
...
...
@@ -1113,19 +1113,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
@@ -1159,19 +1159,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
139b950f
...
...
@@ -1086,21 +1086,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
ck
::
index_t
M01
,
ck
::
index_t
N01
,
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
const
ck
::
index_t
M01
,
const
ck
::
index_t
N01
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
@@ -1194,16 +1194,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
WeiElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
index_t
Conv_G_
;
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads_
;
index_t
k_batch_
;
const
index_t
Conv_G_
;
const
index_t
Conv_N_
;
const
index_t
Conv_K_
;
const
index_t
Conv_C_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads_
;
const
index_t
k_batch_
;
};
// Invoker
...
...
@@ -1390,23 +1390,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
ck
::
index_t
split_k
)
const
ck
::
index_t
split_k
)
{
return
Argument
{
p_in_grid
,
p_wei_grid
,
...
...
@@ -1438,23 +1438,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
ck
::
index_t
split_k
)
override
const
ck
::
index_t
split_k
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_instance.hpp
View file @
139b950f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
...
...
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
View file @
139b950f
...
...
@@ -9,6 +9,7 @@
#include <gtest/gtest.h>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
...
...
@@ -23,11 +24,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
using
InLayout
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
WeiLayout
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
OutLayout
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
NDimSpatial
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
std
::
vector
<
ck
::
utils
::
conv
::
ConvParam
>
conv_params
;
ck
::
index_t
split_k
{
2
};
template
<
ck
::
index_t
NDimSpatial
>
void
Run
()
{
EXPECT_FALSE
(
conv_params
.
empty
());
...
...
@@ -35,7 +36,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
for
(
auto
&
param
:
conv_params
)
{
pass
=
pass
&&
ck
::
profiler
::
profile_grouped_conv_bwd_weight_impl
<
NDimSpatial
,
pass
=
pass
&&
ck
::
profiler
::
profile_grouped_conv_bwd_weight_impl
<
NDimSpatial
{}
,
InLayout
,
WeiLayout
,
OutLayout
,
...
...
@@ -70,21 +71,21 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
using
namespace
ck
::
tensor_layout
::
convolution
;
using
KernelTypes1d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
float
,
float
,
GNWC
,
GKXC
,
GNWK
>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNWC
,
GKXC
,
GNWK
>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNWC
,
GKXC
,
GNWK
>>
;
using
KernelTypes2d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
float
,
float
,
GNHWC
,
GKYXC
,
GNHWK
>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNHWC
,
GKYXC
,
GNHWK
>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNHWC
,
GKYXC
,
GNHWK
>
,
std
::
tuple
<
float
,
float
,
float
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
>>
;
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
float
,
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
>>
;
using
KernelTypes1d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
float
,
float
,
GNWC
,
GKXC
,
GNWK
,
ck
::
Number
<
1
>
>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNWC
,
GKXC
,
GNWK
,
ck
::
Number
<
1
>
>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNWC
,
GKXC
,
GNWK
,
ck
::
Number
<
1
>
>>
;
using
KernelTypes2d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
float
,
float
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>
>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>
>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>
>
,
std
::
tuple
<
float
,
float
,
float
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>
>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>
>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>
>>
;
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
float
,
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>
>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>
>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>
>>
;
TYPED_TEST_SUITE
(
TestGroupedConvndBwdWeight1d
,
KernelTypes1d
);
TYPED_TEST_SUITE
(
TestGroupedConvndBwdWeight2d
,
KernelTypes2d
);
...
...
@@ -96,7 +97,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D)
this
->
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
14
},
{
2
},
{
1
},
{
0
},
{
0
}});
this
->
conv_params
.
push_back
({
1
,
2
,
32
,
128
,
256
,
{
3
},
{
28
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
3
},
{
1
},
{
1
},
{
0
},
{
0
}});
this
->
template
Run
<
1
>
();
this
->
Run
();
}
TYPED_TEST
(
TestGroupedConvndBwdWeight2d
,
Test2D
)
...
...
@@ -108,7 +109,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
{
2
,
2
,
4
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
{
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
template
Run
<
2
>
();
this
->
Run
();
}
TYPED_TEST
(
TestGroupedConvndBwdWeight3d
,
Test3D
)
...
...
@@ -120,5 +121,5 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
{
3
,
2
,
2
,
128
,
256
,
{
3
,
3
,
3
},
{
14
,
14
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
2
,
32
,
128
,
256
,
{
1
,
1
,
1
},
{
3
,
3
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
this
->
template
Run
<
3
>
();
this
->
Run
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment