gaoqiong / composable_kernel · Commits

Commit 822a1110
Authored Sep 07, 2023 by Bartlomiej Kocot

Add M and N padding

Parent: 5112a51e
Showing 2 changed files with 158 additions and 129 deletions:

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp (+145, -126)
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp (+13, -3)
File: include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @ 822a1110
@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -269,14 +270,10 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
        const auto OutWStride = e_g_n_k_wos_strides[spatial_offset];

        const index_t GemmKTotal = N * Wo;
        const index_t GemmM      = K;
        const index_t GemmN      = C * X;

        const index_t GemmKBatch = batch_k;
        const index_t GemmK0 =
            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
            K0PerBlock;
        const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;

        if constexpr(ConvBackwardWeightSpecialization ==
                     ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
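For orientation, here is a minimal arithmetic sketch of how the padded K extent above is derived. The tile sizes and problem size below are hypothetical, not values taken from this diff:

    #include <cstdint>

    // Illustrative only: assumed tile configuration.
    constexpr std::int64_t GemmK1Number = 8;
    constexpr std::int64_t K0PerBlock   = 16;
    constexpr std::int64_t GemmKBatch   = 1; // split-K factor

    constexpr std::int64_t integer_divide_ceil(std::int64_t a, std::int64_t b)
    {
        return (a + b - 1) / b;
    }

    // Example reduction length: GemmKTotal = N * Wo = 4 * 100 = 400.
    constexpr std::int64_t GemmKTotal = 4 * 100;
    constexpr std::int64_t GemmK0 =
        integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock;
    constexpr std::int64_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;

    // ceil(400 / 128) = 4, so GemmK0 = 64 and GemmKPad = 512: the K dimension is
    // rounded up to a whole number of K0PerBlock * GemmK1Number tiles per batch.
    static_assert(GemmK0 == 64 && GemmKPad == 512, "worked example");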
@@ -285,17 +282,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
                make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));

-           const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-               out_gemmktotal_gemmm_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto out_gemmkpad_gemmmpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               out_gemmktotal_gemmm_grid_desc,
+               make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+               Sequence<true, true>{});

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-               out_gemmkpad_gemmm_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               out_gemmkpad_gemmmpad_grid_desc,
+               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                          make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
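The hunk above replaces the hand-written right-pad transform with PadTensorDescriptor from the newly included matrix_padder.hpp. Conceptually, each listed dimension is right-padded up to the next multiple of its tile size whenever the matching flag in the Sequence<bool...> is true. A rough standalone sketch of that rounding rule (this helper is illustrative only and is not the CK API):

    #include <array>
    #include <cstddef>
    #include <cstdint>

    // Illustrative stand-in for the per-dimension padding rule: lengths[i] is
    // right-padded up to a multiple of tile_sizes[i] whenever do_pad[i] is true.
    template <std::size_t N>
    constexpr std::array<std::int64_t, N> pad_lengths(std::array<std::int64_t, N> lengths,
                                                      std::array<std::int64_t, N> tile_sizes,
                                                      std::array<bool, N> do_pad)
    {
        std::array<std::int64_t, N> padded{};
        for(std::size_t i = 0; i < N; ++i)
        {
            const std::int64_t rounded =
                (lengths[i] + tile_sizes[i] - 1) / tile_sizes[i] * tile_sizes[i];
            padded[i] = do_pad[i] ? rounded : lengths[i];
        }
        return padded;
    }

    // E.g. a (GemmKTotal, GemmM) = (400, 100) descriptor with tile sizes (128, 64)
    // and both flags true becomes (512, 128); the pad region lies past the real data.
    constexpr auto padded = pad_lengths<2>({400, 100}, {128, 64}, {true, true});
    static_assert(padded[0] == 512 && padded[1] == 128, "rounds up to tile multiples");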
@@ -303,17 +300,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
                make_tuple(N * Wi, C), make_tuple(InWStride, InCStride));

-           const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-               in_gemmktotal_gemmn_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto in_gemmkpad_gemmnpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               in_gemmktotal_gemmn_grid_desc,
+               make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+               Sequence<true, true>{});

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-               in_gemmkpad_gemmn_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               in_gemmkpad_gemmnpad_grid_desc,
+               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                          make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -321,9 +318,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
                make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));

+           const auto wei_gemmmpad_gemmnpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               wei_gemmm_gemmn_grid_desc, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{});

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                             wei_gemmm_gemmn_grid_desc);
+                             wei_gemmmpad_gemmnpad_grid_desc);
        }
        else
        {
@@ -333,17 +335,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
                make_tuple(N, Wi, C),
                make_tuple(InNStride, InWStride, InCStride));

            // A: output tensor
-           const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-               out_gemmktotal_gemmm_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto out_gemmkpad_gemmmpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               out_gemmktotal_gemmm_grid_desc,
+               make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+               Sequence<true, true>{});

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-               out_gemmkpad_gemmm_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               out_gemmkpad_gemmmpad_grid_desc,
+               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                          make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -372,17 +374,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
                make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<1>{}, Sequence<0>{}));

-           const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-               in_gemmktotal_gemmn_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmN)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto in_gemmkpad_gemmnpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               in_gemmktotal_gemmn_grid_desc,
+               make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+               Sequence<true, true>{});

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-               in_gemmkpad_gemmn_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmN)),
+               in_gemmkpad_gemmnpad_grid_desc,
+               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                          make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -390,9 +392,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
                make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));

+           const auto wei_gemmmpad_gemmnpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               wei_gemmm_gemmn_grid_desc, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{});

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                             wei_gemmm_gemmn_grid_desc);
+                             wei_gemmmpad_gemmnpad_grid_desc);
        }
    }
    // function end
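The effect of padding the weight (M, N) descriptor returned above is easiest to see with concrete numbers. A hedged sketch with hypothetical sizes and tile dimensions (not the values of any instance in this file):

    // Illustrative only: assumed tile sizes and a problem whose GEMM dims are
    // not tile multiples.
    constexpr int MPerBlock = 64;
    constexpr int NPerBlock = 64;
    constexpr int GemmM = 96;     // e.g. K = 96 output channels
    constexpr int GemmN = 3 * 17; // e.g. X * C = 51

    constexpr int GemmMPad = (GemmM + MPerBlock - 1) / MPerBlock * MPerBlock; // 128
    constexpr int GemmNPad = (GemmN + NPerBlock - 1) / NPerBlock * NPerBlock; // 64

    // With M and N padding, the DL GEMM can tile this problem even though
    // neither dimension is a multiple of its block size.
    static_assert(GemmMPad == 128 && GemmNPad == 64, "rounded up to tile multiples");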
@@ -441,14 +448,10 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
        const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1];

        const index_t GemmKTotal = N * Ho * Wo;
        const index_t GemmM      = K;
        const index_t GemmN      = C * X * Y;

        const index_t GemmKBatch = batch_k;
        const index_t GemmK0 =
            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
            K0PerBlock;
        const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;

        if constexpr(ConvBackwardWeightSpecialization ==
                     ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
@@ -457,17 +460,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
                make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));

-           const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-               out_gemmktotal_gemmm_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto out_gemmkpad_gemmmpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               out_gemmktotal_gemmm_grid_desc,
+               make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+               Sequence<true, true>{});

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-               out_gemmkpad_gemmm_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               out_gemmkpad_gemmmpad_grid_desc,
+               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                          make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -475,17 +478,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
                make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride));

-           const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-               in_gemmktotal_gemmn_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto in_gemmkpad_gemmnpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               in_gemmktotal_gemmn_grid_desc,
+               make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+               Sequence<true, true>{});

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-               in_gemmkpad_gemmn_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               in_gemmkpad_gemmnpad_grid_desc,
+               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                          make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -493,9 +496,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
                make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));

+           const auto wei_gemmmpad_gemmnpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               wei_gemmm_gemmn_grid_desc, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{});

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                             wei_gemmm_gemmn_grid_desc);
+                             wei_gemmmpad_gemmnpad_grid_desc);
        }
        else
        {
@@ -505,17 +513,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
                make_tuple(N, Hi, Wi, C),
                make_tuple(InNStride, InHStride, InWStride, InCStride));

            // A: output tensor
-           const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-               out_gemmktotal_gemmm_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto out_gemmkpad_gemmmpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               out_gemmktotal_gemmm_grid_desc,
+               make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+               Sequence<true, true>{});

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-               out_gemmkpad_gemmm_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               out_gemmkpad_gemmmpad_grid_desc,
+               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                          make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -546,17 +554,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
                make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
                make_tuple(Sequence<1>{}, Sequence<0>{}));

-           const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-               in_gemmktotal_gemmn_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmN)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto in_gemmkpad_gemmnpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               in_gemmktotal_gemmn_grid_desc,
+               make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+               Sequence<true, true>{});

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-               in_gemmkpad_gemmn_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmN)),
+               in_gemmkpad_gemmnpad_grid_desc,
+               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                          make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -564,9 +572,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
                make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));

+           const auto wei_gemmmpad_gemmnpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               wei_gemmm_gemmn_grid_desc, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{});

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                             wei_gemmm_gemmn_grid_desc);
+                             wei_gemmmpad_gemmnpad_grid_desc);
        }
    }
    // function end
@@ -624,14 +637,10 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
        const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2];

        const index_t GemmKTotal = N * Do * Ho * Wo;
        const index_t GemmM      = K;
        const index_t GemmN      = C * Z * X * Y;

        const index_t GemmKBatch = batch_k;
        const index_t GemmK0 =
            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
            K0PerBlock;
        const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;

        if constexpr(ConvBackwardWeightSpecialization ==
                     ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
@@ -640,17 +649,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
                make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));

-           const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-               out_gemmktotal_gemmm_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto out_gemmkpad_gemmmpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               out_gemmktotal_gemmm_grid_desc,
+               make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+               Sequence<true, true>{});

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-               out_gemmkpad_gemmm_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               out_gemmkpad_gemmmpad_grid_desc,
+               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                          make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -658,17 +667,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
                make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride));

-           const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-               in_gemmktotal_gemmn_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto in_gemmkpad_gemmnpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               in_gemmktotal_gemmn_grid_desc,
+               make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+               Sequence<true, true>{});

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-               in_gemmkpad_gemmn_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               in_gemmkpad_gemmnpad_grid_desc,
+               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                          make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -676,9 +685,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
                make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));

+           const auto wei_gemmmpad_gemmnpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               wei_gemmm_gemmn_grid_desc, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{});

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                             wei_gemmm_gemmn_grid_desc);
+                             wei_gemmmpad_gemmnpad_grid_desc);
        }
        else
        {
@@ -689,17 +703,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
                make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride));

            // A: output tensor
-           const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-               out_gemmktotal_gemmm_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto out_gemmkpad_gemmmpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               out_gemmktotal_gemmm_grid_desc,
+               make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+               Sequence<true, true>{});

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-               out_gemmkpad_gemmm_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               out_gemmkpad_gemmmpad_grid_desc,
+               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                          make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -739,17 +753,17 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
                make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
                make_tuple(Sequence<1>{}, Sequence<0>{}));

-           const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-               in_gemmktotal_gemmn_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmN)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto in_gemmkpad_gemmnpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               in_gemmktotal_gemmn_grid_desc,
+               make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+               Sequence<true, true>{});

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-               in_gemmkpad_gemmn_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmN)),
+               in_gemmkpad_gemmnpad_grid_desc,
+               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                          make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -757,9 +771,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
                make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));

+           const auto wei_gemmmpad_gemmnpad_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
+               wei_gemmm_gemmn_grid_desc, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{});

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                             wei_gemmm_gemmn_grid_desc);
+                             wei_gemmmpad_gemmnpad_grid_desc);
        }
    }
    // function end
File: test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @ 822a1110
@@ -14,6 +14,8 @@
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"

+using namespace ck::tensor_layout::convolution;
+
template <typename Tuple>
class TestGroupedConvndBwdWeight : public ::testing::Test
{
@@ -35,7 +37,17 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
        // dl kernel is only supported for split_k=1
        if constexpr(std::is_same_v<InDataType, ck::half_t>)
        {
-           if(split_k == 1 && (params.K_ == 1 || params.C_ == 1))
+           if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0))
            {
                return true;
            }
        }
+       // 1d nhwgc is only supported by dl kernel
+       // dl kernel is only supported for split_k=1
+       if constexpr(std::is_same_v<InLayout, NWGC> && std::is_same_v<OutLayout, NWGK>)
+       {
+           if(split_k != 1)
+           {
+               return true;
+           }
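The gist of the fp16 gating change above, restated as a standalone predicate (an illustrative sketch, not the test code; the parameter names are assumptions): previously, fp16 DL cases with K == 1 or C == 1 were skipped at split_k == 1, whereas with M and N padding in place only split-K runs with odd K or C need to be skipped.

    // Hypothetical helper mirroring the new skip rule for fp16 inputs.
    bool should_skip_fp16_case(int split_k, int K, int C)
    {
        // Odd K or C is fine for split_k == 1 thanks to M/N padding; it is
        // still skipped when split-K is used.
        return split_k != 1 && (K % 2 != 0 || C % 2 != 0);
    }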
@@ -90,8 +102,6 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
{
};

-using namespace ck::tensor_layout::convolution;
-
using KernelTypes1d = ::testing::Types<
    std::tuple<float, float, float, GNWC, GKXC, GNWK, ck::Number<1>>,
    std::tuple<ck::half_t, ck::half_t, ck::half_t, GNWC, GKXC, GNWK, ck::Number<1>>,
    ...
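For readers unfamiliar with the pattern, a Types<> list like KernelTypes1d feeds GoogleTest's typed-test machinery roughly as sketched below (a generic example with made-up element types, not this file's actual fixture):

    #include <gtest/gtest.h>
    #include <tuple>
    #include <type_traits>

    // Each std::tuple in the Types<> list becomes one instantiation of the
    // fixture, visible inside the test body as TypeParam.
    template <typename Tuple>
    class ExampleTypedTest : public ::testing::Test
    {
    };

    using ExampleTypes =
        ::testing::Types<std::tuple<float, float, float>, std::tuple<double, double, double>>;
    TYPED_TEST_SUITE(ExampleTypedTest, ExampleTypes);

    TYPED_TEST(ExampleTypedTest, RunsOncePerTuple)
    {
        using InDataType = std::tuple_element_t<0, TypeParam>;
        static_assert(std::is_floating_point_v<InDataType>, "first element is the input type");
    }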