gaoqiong / composable_kernel_ROCM / Commits / 9a99c841

Commit 9a99c841 authored Aug 26, 2024 by aska-0096

    temp save

parent 4f65f7b3
Showing 2 changed files with 180 additions and 14 deletions
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp   +180  -13
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp      +0    -1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
...
@@ -1574,7 +1574,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                    Number<NumDTensor>{}));

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
-           c_grid_desc_mblock_mperblock_nblock_nperblock;
+           MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+               c_grid_desc_m_n, problem.MBlock, problem.NBlock);

        using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
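
The new E grid descriptor is rebuilt from c_grid_desc_m_n and the block counts instead of aliasing the C descriptor. For intuition, the blocking it encodes is the plain (M, N) -> (MBlock, MPerBlock, NBlock, NPerBlock) split; below is a minimal host-side sketch with hypothetical names (the real MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock returns a CK tensor descriptor, not integers):

    #include <cassert>
    #include <cstdint>

    // Hypothetical integer-only view of the blocking: an (M, N) output is read as a
    // 4-D (MBlock, MPerBlock, NBlock, NPerBlock) tensor, so each workgroup owns one
    // (MPerBlock, NPerBlock) tile addressed by (mblock, nblock).
    struct BlockedShape
    {
        std::int64_t MBlock, MPerBlock, NBlock, NPerBlock;
    };

    inline BlockedShape make_blocked_shape(std::int64_t M, std::int64_t N,
                                           std::int64_t MPerBlock, std::int64_t NPerBlock)
    {
        assert(M % MPerBlock == 0 && N % NPerBlock == 0); // padding is handled elsewhere in CK
        return {M / MPerBlock, MPerBlock, N / NPerBlock, NPerBlock};
    }

    // Example with assumed sizes: M = N = 4096, MPerBlock = 256, NPerBlock = 128
    // gives MBlock = 16 and NBlock = 32.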
...
@@ -1615,7 +1616,74 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
            // copy multiple D from global to vgpr
            auto d_threadwise_copy;

            // copy c from vgpr to lds
-           auto c_threadwise_copy_vgpr_to_lds =
+           // TODO: Avoid bank conflict. 2 for mfma16x16, 0 for mfma32x32.
+           constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
+           constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
+
+           // C SrcDesc in VGPR
+           constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
+               blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+           // C DstDesc in LDS
+           constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
+               blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+           constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
+           constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
+           constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
+           constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
+           constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
+           constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
+           constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
+           constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
+
+           constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
+               GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
+
+           auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+               static_cast<CShuffleDataType*>(p_shared),
+               c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+
+           constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
+               c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+               make_tuple(
+                   make_freeze_transform(I0),
+                   make_unmerge_transform(make_tuple(
+                       Number<CShuffleMXdlPerWavePerShuffle>{}, M1, M2, M3, M4)),
+                   make_freeze_transform(I0),
+                   make_unmerge_transform(make_tuple(
+                       Number<CShuffleNXdlPerWavePerShuffle>{}, N1, N2))),
+               make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+               make_tuple(Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
+
+           const auto c_thread_mtx_on_block =
+               blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+
+           const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
+           const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
+
+           const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
+               make_single_stage_tensor_adaptor(
+                   make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
+                   make_tuple(Sequence<0, 1, 2, 3, 4>{}),
+                   make_tuple(Sequence<0>{}));
+
+           const auto m_thread_data_on_block_idx =
+               m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
+                   make_multi_index(m_thread_data_on_block));
+
+           const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
+               make_single_stage_tensor_adaptor(
+                   make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
+                   make_tuple(Sequence<0, 1, 2>{}),
+                   make_tuple(Sequence<0>{}));
+
+           const auto n_thread_data_on_block_idx =
+               n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
+                   make_multi_index(n_thread_data_on_block));
+
+           // shuffle: threadwise copy C from VGPR to LDS
+           auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   CShuffleDataType,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
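
The two make_single_stage_tensor_adaptor calls above map the flat per-thread row and column offsets within the block back to the (M0, M1, M2, M3, M4) and (N0, N1, N2) coordinates. Conceptually this is a mixed-radix decomposition of the flat index over the dimension lengths; below is a standalone sketch under that assumption, using plain arrays instead of CK descriptors and lengths picked purely for illustration:

    #include <array>
    #include <cstdint>

    // Splits a flat index over the given lengths, fastest-varying dimension last,
    // i.e. the inverse of merging (i0, i1, ..., iN-1) into one linear index.
    template <std::size_t N>
    std::array<std::int64_t, N> decompose(std::int64_t flat,
                                          const std::array<std::int64_t, N>& lengths)
    {
        std::array<std::int64_t, N> idx{};
        for(std::size_t i = N; i-- > 0;)
        {
            idx[i] = flat % lengths[i];
            flat /= lengths[i];
        }
        return idx;
    }

    // Example with assumed lengths (M0=2, M1=2, M2=4, M3=1, M4=4):
    // decompose<5>(37, {2, 2, 4, 1, 4}) == {1, 0, 1, 0, 1},
    // since (((1*2 + 0)*4 + 1)*1 + 0)*4 + 1 == 37.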
...
@@ -1647,9 +1715,85 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                ck::tensor_operation::element_wise::PassThrough{}};

            // copy c from lds to vgpr
            auto c_threadwise_copy_lds_to_vgpr;

-           // copy e from vgpr to vgpr
-           auto e_threadwise_copy;
+           // copy e from vgpr to global
+           constexpr auto n_vec    = CDEShuffleBlockTransferScalarPerVectors.At(Number<0>{});
+           constexpr auto n_thread = NPerBlock / n_vec / NRepeat;
+           constexpr auto m_thread = BlockSize / n_thread;
+           constexpr auto m_thread_repeat = MPerBlock / MRepeat / m_thread;
+
+           auto c_thread_desc_coalescing = make_naive_tensor_descriptor_packed(
+               make_tuple(I1, MRepeat, m_thread_repeat, I1, I1, NRepeat, I1, n_vec));
+
+           auto c_block_desc_coalescing = make_naive_tensor_descriptor_packed(
+               make_tuple(I1, MRepeat, m_thread_repeat, m_thread, I1, NRepeat, n_thread, n_vec));
+
+           auto c_thread_copy_lds_to_vgpr =
+               ThreadwiseTensorSliceTransfer_v2<CShuffleDataType,
+                                                CShuffleDataType,
+                                                decltype(c_block_desc_coalescing),
+                                                decltype(c_thread_desc_coalescing),
+                                                Sequence<I1,
+                                                         CShuffleMXdlPerWavePerShuffle,
+                                                         m_thread_repeat,
+                                                         I1,
+                                                         I1,
+                                                         CShuffleMXdlPerWavePerShuffle,
+                                                         I1,
+                                                         n_vec>,
+                                                Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                                                6,
+                                                ABlockTransferSrcScalarPerVector,
+                                                AThreadTransferSrcResetCoordinateAfterRun,
+                                                true>(
+                   a_grid_desc,
+                   make_multi_index(0,
+                                    m_block_data_idx_on_grid / (MWaves * MPerWmma),
+                                    get_thread_local_1d_id() / 32,
+                                    0,
+                                    (get_thread_local_1d_id() % 32) / 16,
+                                    get_thread_local_1d_id() % 16,
+                                    0));
+
+           auto e_grid_desc_coalescing = make_naive_tensor_descriptor_packed(
+               make_tuple(problem.MBlock, MRepeat, m_thread_repeat, m_thread,
+                          problem.NBlock, NRepeat, n_thread, n_vec));
+
+           auto e_threadwise_copy =
+               ThreadwiseTensorSliceTransfer_v1r3<EDataType,
+                                                  EDataType,
+                                                  decltype(c_thread_desc_coalescing),
+                                                  decltype(e_grid_desc_coalescing),
+                                                  ck::tensor_operation::element_wise::PassThrough,
+                                                  Sequence<I1,
+                                                           CShuffleMXdlPerWavePerShuffle,
+                                                           m_thread_repeat,
+                                                           I1,
+                                                           I1,
+                                                           CShuffleMXdlPerWavePerShuffle,
+                                                           I1,
+                                                           n_vec>,
+                                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                                                  7,
+                                                  n_vec,
+                                                  CGlobalMemoryDataOperation,
+                                                  1,
+                                                  true>{
+                   e_grid_desc_coalescing,
+                   make_multi_index(block_m_id,
+                                    I0,
+                                    I0,
+                                    get_thread_local_1d_id() / n_thread,
+                                    block_n_id,
+                                    I0,
+                                    get_thread_local_1d_id() % n_thread,
+                                    I0),
+                   ck::tensor_operation::element_wise::PassThrough{}};

            auto xdlops_gemm = blockwise_gemm_pipeline.xdlops_gemm;

            constexpr auto MRepeat = MXdlPerWave;
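
The coalescing split above is pure integer arithmetic on the tile parameters. A worked example with assumed values (not taken from this commit) shows the intended shapes:

    #include <cstdio>

    // Assumed tile parameters: NPerBlock = 128, n_vec = 8, NRepeat = 2,
    // BlockSize = 256, MPerBlock = 128, MRepeat = 2.
    int main()
    {
        constexpr int NPerBlock = 128, n_vec = 8, NRepeat = 2;
        constexpr int BlockSize = 256, MPerBlock = 128, MRepeat = 2;

        constexpr int n_thread        = NPerBlock / n_vec / NRepeat;    // 8 threads span N per repeat
        constexpr int m_thread        = BlockSize / n_thread;           // 32 threads span M
        constexpr int m_thread_repeat = MPerBlock / MRepeat / m_thread; // 2 rows per thread per repeat

        // Each thread then owns a (1, MRepeat, m_thread_repeat, 1, 1, NRepeat, 1, n_vec) slice,
        // while the block tile is (1, MRepeat, m_thread_repeat, m_thread, 1, NRepeat, n_thread, n_vec).
        std::printf("n_thread=%d m_thread=%d m_thread_repeat=%d\n", n_thread, m_thread, m_thread_repeat);
        return 0;
    }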
...
@@ -1663,8 +1807,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
            static_for<0, MRepeat / CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto shuffle_m0) {
                static_for<0, NRepeat / CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto shuffle_n0) {
                    // MultipleD buffer load
                    d_threadwise_copy.Run(c_ds_desc_refs,
                                          c_ds_buf_refs,
                                          tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                                          tie(c_grid_buf));
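
shuffle_m0 and shuffle_n0 are compile-time constants produced by ck::static_for, which is why they can later appear inside Number<...> template arguments. A minimal standalone sketch of that kind of unrolled counted loop (not the CK implementation) is:

    #include <utility>

    // Calls f(std::integral_constant<int, I>{}) for I = Begin, Begin+Step, ... while I < End.
    // The recursion is resolved at compile time, so the loop is fully unrolled and the
    // index stays usable as a template argument inside the body.
    template <int Begin, int End, int Step, typename F>
    void static_for_sketch(F&& f)
    {
        if constexpr(Begin < End)
        {
            f(std::integral_constant<int, Begin>{});
            static_for_sketch<Begin + Step, End, Step>(std::forward<F>(f));
        }
    }

    // Usage mirroring the nested shuffle loops above (repeat counts are assumed):
    //   static_for_sketch<0, 2, 1>([&](auto shuffle_m0) {
    //       static_for_sketch<0, 2, 1>([&](auto shuffle_n0) { /* per-shuffle epilogue work */ });
    //   });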
...
@@ -1708,12 +1851,36 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                });
            });

            // Shuffle: DS_WRITE
-           c_thread_copy_vgpr_to_lds.Run();
+           c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                         make_tuple(Number<shuffle_m0 * MRepeat>{},
+                                                    Number<shuffle_n0 * NRepeat>{},
+                                                    Number<0>{},
+                                                    Number<0>{},
+                                                    Number<0>{},
+                                                    Number<0>{},
+                                                    Number<0>{},
+                                                    Number<0>{}),
+                                         c_thread_buf,
+                                         c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                         c_shuffle_block_buf);

            block_sync_lds();

            // Shuffle: DS_READ
-           e_blockwise_copy.RunRead();
+           c_thread_copy_lds_to_vgpr.Run();

            cde_element();

-           e_blockwise_copy.RunWrite();
+           e_threadwise_copy.Run(c_thread_desc_coalescing,
+                                 make_tuple(Number<0>{},
+                                            Number<0>{},
+                                            Number<0>{},
+                                            Number<0>{},
+                                            Number<0>{},
+                                            Number<0>{},
+                                            Number<0>{},
+                                            Number<0>{}),
+                                 e_thread_buf,
+                                 e_grid_desc_coalescing,
+                                 c_grid_buf);
+           // move e_grid desc slice origin
        });
    });
}
...
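
e_threadwise_copy was constructed in the +1715 hunk with a start coordinate derived from the flat thread id, (block_m_id, 0, 0, tid / n_thread, block_n_id, 0, tid % n_thread, 0), in the (MBlock, MRepeat, m_thread_repeat, m_thread, NBlock, NRepeat, n_thread, n_vec) view. A small sketch of that mapping, with n_thread assumed to be 8 as in the earlier worked example:

    #include <cstdio>

    // Prints which (m_thread, n_thread) slot a few flat thread ids land in, matching the
    // make_multi_index(..., tid / n_thread, ..., tid % n_thread, ...) start coordinate above.
    int main()
    {
        constexpr int n_thread = 8; // assumed value, not taken from the commit
        for(int tid : {0, 1, 7, 8, 63, 255})
            std::printf("tid=%3d -> m slot %2d, n slot %d\n", tid, tid / n_thread, tid % n_thread);
        return 0;
    }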
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...
@@ -602,7 +602,6 @@ struct ThreadwiseTensorSliceTransfer_v2r1
    SrcCoord src_coord_;
}; // namespace ck

// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
...