Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7a20cf67
Commit
7a20cf67
authored
Sep 18, 2024
by
Jakub Piasecki
Browse files
tmp save
parent
a793afc9
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1223 additions
and
23 deletions
+1223
-23
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
...vice_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+44
-23
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v2.hpp
...operator_transform/transform_conv_bwd_data_to_gemm_v2.hpp
+1179
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
View file @
7a20cf67
...
@@ -76,7 +76,6 @@ __global__ void
...
@@ -76,7 +76,6 @@ __global__ void
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
index_t
batch_count
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
@@ -89,18 +88,19 @@ __global__ void
...
@@ -89,18 +88,19 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
defined(__gfx94__))
// offset base pointer for each work-group
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
// const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
// __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
//const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
);
const
long_index_t
a_
batch
_offset
=
const
long_index_t
a_
group
_offset
=
amd_wave_read_first_lane
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
));
amd_wave_read_first_lane
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
));
const
long_index_t
b_
batch
_offset
=
const
long_index_t
b_
group
_offset
=
amd_wave_read_first_lane
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
));
amd_wave_read_first_lane
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
));
const
long_index_t
e_
batch
_offset
=
const
long_index_t
e_
group
_offset
=
amd_wave_read_first_lane
(
compute_ptr_offset_of_batch
.
GetEPtrOffset
(
g_idx
));
amd_wave_read_first_lane
(
compute_ptr_offset_of_batch
.
GetEPtrOffset
(
g_idx
));
const
auto
ds_
batch
_offset
=
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
const
auto
ds_
group
_offset
=
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
@@ -110,12 +110,12 @@ __global__ void
...
@@ -110,12 +110,12 @@ __global__ void
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
::
Size
();
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
::
Size
();
static_for
<
0
,
NumDTensor
,
1
>
{}(
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_ds_grid_grp
(
i
)
=
p_ds_grid
[
i
]
+
ds_
batch
_offset
[
i
];
});
[
&
](
auto
i
)
{
p_ds_grid_grp
(
i
)
=
p_ds_grid
[
i
]
+
ds_
group
_offset
[
i
];
});
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_
batch
_offset
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_
group
_offset
,
p_b_grid
+
b_
batch
_offset
,
p_b_grid
+
b_
group
_offset
,
p_ds_grid_grp
,
p_ds_grid_grp
,
p_e_grid
+
e_
batch
_offset
,
p_e_grid
+
e_
group
_offset
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -130,7 +130,6 @@ __global__ void
...
@@ -130,7 +130,6 @@ __global__ void
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
p_e_grid
;
ignore
=
batch_count
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
...
@@ -200,7 +199,8 @@ template <index_t NDimSpatial,
...
@@ -200,7 +199,8 @@ template <index_t NDimSpatial,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
typename
AComputeType
=
ADataType
,
typename
AComputeType
=
ADataType
,
typename
BComputeType
=
AComputeType
>
typename
BComputeType
=
AComputeType
,
index_t
NumGroupsToMerge
=
1
>
struct
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
struct
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
:
public
DeviceGroupedConvBwdDataMultipleD
<
NDimSpatial
,
:
public
DeviceGroupedConvBwdDataMultipleD
<
NDimSpatial
,
ALayout
,
// output image
ALayout
,
// output image
...
@@ -220,6 +220,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -220,6 +220,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// TODO: Extend support for more spatial dimensions.
// TODO: Extend support for more spatial dimensions.
static_assert
(
NDimSpatial
==
2
||
NDimSpatial
==
3
,
static_assert
(
NDimSpatial
==
2
||
NDimSpatial
==
3
,
"wrong! only implemented for 2D and 3D now"
);
"wrong! only implemented for 2D and 3D now"
);
static_assert
(
NumGroupsToMerge
>=
1
,
"wrong! NumGroupsToMerge must be greater or equal to 1!"
);
using
DeviceOp
=
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
;
using
DeviceOp
=
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
;
...
@@ -242,7 +244,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -242,7 +244,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
DoPadGemmM
,
DoPadGemmM
,
DoPadGemmN
>
{};
DoPadGemmN
,
NumGroupsToMerge
>
{};
static
auto
GetDummyABDsEGridDescriptor
()
static
auto
GetDummyABDsEGridDescriptor
()
{
{
...
@@ -453,10 +456,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -453,10 +456,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// A/B/Ds/E Batch Stride
// A/B/Ds/E Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
a_g_n_k_wos_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
a_g_n_k_wos_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
e_g_n_c_wis_strides
[
0
]
;
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
e_g_n_c_wis_strides
[
0
]
*
NumGroupsToMerge
;
// sure?
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_c_wis_strides
[
i
][
0
];
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_c_wis_strides
[
i
][
0
]
*
NumGroupsToMerge
;
});
});
static
constexpr
auto
NonSpatialDimsNum
=
Number
<
3
>
{};
static
constexpr
auto
NonSpatialDimsNum
=
Number
<
3
>
{};
...
@@ -713,7 +716,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -713,7 +716,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
float
ave_time
=
0
;
float
ave_time
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_ak0_m_ak1_container_
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_ak0_m_ak1_container_
.
size
();
i
++
)
// why is there a for?
{
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_container_
[
i
],
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_container_
[
i
],
arg
.
b_grid_desc_n_k_container_
[
i
],
arg
.
b_grid_desc_n_k_container_
[
i
],
...
@@ -752,7 +755,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -752,7 +755,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
return
launch_and_time_kernel
(
return
launch_and_time_kernel
(
stream_config
,
stream_config
,
kernel
,
kernel
,
dim3
(
grid_size
),
dim3
(
grid_size
),
// change to gdx, gdy, gdz after removing for and calculating all in one go
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
...
@@ -762,7 +765,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -762,7 +765,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_g_n_k_wos_lengths_
[
0
],
// Group count
//
arg.a_g_n_k_wos_lengths_[0], // Group count
arg
.
a_grid_desc_ak0_m_ak1_container_
[
i
],
arg
.
a_grid_desc_ak0_m_ak1_container_
[
i
],
arg
.
b_grid_desc_bk0_n_bk1_container_
[
i
],
arg
.
b_grid_desc_bk0_n_bk1_container_
[
i
],
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
[
i
],
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
[
i
],
...
@@ -798,8 +801,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -798,8 +801,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
return
false
;
return
false
;
}
}
const
index_t
ConvK
=
arg
.
b_g_k_c_xs_lengths_
[
1
];
const
index_t
ConvG
=
arg
.
b_g_k_c_xs_lengths_
[
I0
];
const
index_t
ConvC
=
arg
.
b_g_k_c_xs_lengths_
[
2
];
const
index_t
ConvK
=
arg
.
b_g_k_c_xs_lengths_
[
I1
];
const
index_t
ConvC
=
arg
.
b_g_k_c_xs_lengths_
[
I2
];
// Specifialization
// Specifialization
if
constexpr
(
ConvBackwardDataSpecialization
==
if
constexpr
(
ConvBackwardDataSpecialization
==
...
@@ -894,6 +898,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -894,6 +898,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
return
false
;
return
false
;
}
}
if
constexpr
(
NumGroupsToMerge
>
1
)
{
if
(
!
(
ConvC
==
1
))
{
return
false
;
}
if
(
ConvG
%
NumGroupsToMerge
!=
0
)
{
return
false
;
}
if
constexpr
(
!
is_NSpatialGK_GKSpatial_NSpatialGC
<
ALayout
,
BLayout
,
ELayout
>
())
{
return
false
;
}
}
// Gridwise GEMM size
// Gridwise GEMM size
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_ak0_m_ak1_container_
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_ak0_m_ak1_container_
.
size
();
i
++
)
{
{
...
@@ -1031,7 +1051,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
...
@@ -1031,7 +1051,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
NumGroupsToMerge
<<
">"
;
<<
">"
;
return
str
.
str
();
return
str
.
str
();
...
...
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v2.hpp
0 → 100644
View file @
7a20cf67
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment