gaoqiong / composable_kernel_ROCM · Commits

Commit 69977fab, authored Oct 21, 2024 by aska-0096
Parent: 1e339898

tempsave
Showing 7 changed files with 146 additions and 119 deletions (+146 −119)
example/01_gemm/gemm_xdl_fp16_v3.cpp  +3 −3
include/ck/tensor/static_tensor.hpp  +0 −26
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp  +4 −8
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp  +12 −12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp  +25 −23
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  +89 −26
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp  +13 −21
example/01_gemm/gemm_xdl_fp16_v3.cpp

@@ -28,15 +28,15 @@ using DeviceGemmV2Instance =
         ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
         PassThrough, PassThrough, PassThrough, GemmDefault,
         256,
-        224, 256,
+        256, 256,
         64, 8, 8,
         16, 16,
-        7, 8,
+        8, 8,
         S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
         2, 8, 8, 0,
         S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
         1, 8, 8, 0,
         1, 2, S<1, 32, 1, 8>, 8,
+        // TODO: Deprecated
         ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
 // clang-format on
...
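A note on why these two numbers move in lockstep: the block tile in M must factor exactly into XDL tiles per wave, so MPerBlock 224 -> 256 forces MXdlPerWave 7 -> 8. A minimal compile-time sketch, assuming MWaves = 2 (not stated in this hunk; consistent with a 256-thread block of four 64-lane waves split 2 x 2 over M and N):

    int main()
    {
        constexpr int MPerXDL     = 16;  // from the template argument list above
        constexpr int MWaves      = 2;   // assumed
        constexpr int MXdlPerWave = 8;   // was 7 before this commit
        constexpr int MPerBlock   = 256; // was 224 before this commit

        // 8 * 16 * 2 == 256, just as 7 * 16 * 2 == 224 did before.
        static_assert(MPerBlock == MXdlPerWave * MPerXDL * MWaves,
                      "block tile must factor into per-wave XDL tiles");
        return 0;
    }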
include/ck/tensor/static_tensor.hpp

@@ -218,32 +218,6 @@ struct StaticTensorTupleOfVectorBuffer
         }
     }

-    template <typename X,
-              typename Idx,
-              typename enable_if<has_same_scalar_type<S, X>::value &&
-                                     is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
-                                 bool>::type = false>
-    __host__ __device__ constexpr void SetAsType_Print(Idx, X x)
-    {
-        constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
-
-        constexpr index_t offset = coord.GetOffset();
-
-        if(get_thread_local_1d_id() == 0){
-            printf("Tid: %d, Index: (%d, %d, %d, %d), Offset: %d\n",
-                   get_thread_local_1d_id(),
-                   Idx{}.At(Number<0>{}).value, Idx{}.At(Number<1>{}).value,
-                   Idx{}.At(Number<2>{}).value, Idx{}.At(Number<3>{}).value, offset);
-        }
-
-        constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
-
-        if constexpr(is_valid)
-        {
-            data_.template SetAsType<X>(Number<offset>{}, x);
-        }
-    }
-
     // Get read access to V. No is_valid check
     // Idx is for S, not V. Idx should be aligned with V
     template <typename Idx>
...
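The deleted SetAsType_Print was a debug twin of SetAsType: it computed the compile-time offset for an index, printed the index-to-offset mapping from one thread, then performed the guarded store. A host-side sketch of the same pattern, with illustrative names (Desc, set_with_print) rather than CK's API:

    #include <array>
    #include <cstdio>

    template <int N0, int N1>
    struct Desc // packed row-major 2-D descriptor
    {
        static constexpr int offset(int i0, int i1) { return i0 * N1 + i1; }
    };

    template <typename D>
    void set_with_print(std::array<float, 64>& buf, int i0, int i1, float x)
    {
        const int off = D::offset(i0, i1);
        std::printf("Index: (%d, %d), Offset: %d\n", i0, i1, off); // debug only
        buf[off] = x;                                              // the real work
    }

    int main()
    {
        std::array<float, 64> buf{};
        set_with_print<Desc<8, 8>>(buf, 2, 3, 1.0f); // prints "Index: (2, 3), Offset: 19"
        return 0;
    }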
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp

@@ -302,21 +302,17 @@ struct BlockwiseGemmXdlops_pipeline_base
         return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
     }

-    __host__ __device__ static constexpr auto
-    GetCBlockDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
+    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4()
     {
-        constexpr auto c_block_desc_mblock_nblock_m0_n0_m1_n1_m2_n2 =
-            make_naive_tensor_descriptor_packed(make_tuple(I1,
-                                                           I1,
-                                                           Number<MRepeat>{},
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                            Number<NRepeat>{},
                                                            Number<MWaves>{},
                                                            Number<NWaves>{},
                                                            Number<MPerXDL>{},
                                                            Number<NPerXDL>{}));

-        return xdlops_gemm.MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4(
-            c_block_desc_mblock_nblock_m0_n0_m1_n1_m2_n2);
+        return xdlops_gemm.MakeCDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
     }

     __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
...
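The renamed accessor differs from the old one only in the two leading unit dimensions: the MBlock_NBlock variant prepended I1, I1 so the per-block tile shape lined up with a (MBlock, NBlock, ...) view. A sketch, with illustrative tile lengths, showing that the unit-dim prefix never changes the tile's element count:

    #include <cstddef>
    #include <cstdio>

    template <std::size_t N>
    constexpr std::size_t product(const std::size_t (&dims)[N])
    {
        std::size_t p = 1;
        for (std::size_t d : dims) p *= d;
        return p;
    }

    int main()
    {
        // Illustrative tile: MRepeat x NRepeat x MWaves x NWaves x MPerXDL x NPerXDL.
        constexpr std::size_t with_block_dims[]    = {1, 1, 4, 4, 2, 2, 16, 16};
        constexpr std::size_t without_block_dims[] = {4, 4, 2, 2, 16, 16};

        static_assert(product(with_block_dims) == product(without_block_dims),
                      "unit dims do not change the tile size");
        std::printf("elements per block tile: %zu\n", product(without_block_dims));
        return 0;
    }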
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp

@@ -332,12 +332,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
         // Local prefetch 1
         block_sync_lds();
         static_for<0, KRepeat, 1>{}([&](auto k0) {
-            // a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-            //                    make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
-            //                    a_block_buf,
-            //                    a_thread_desc_,
-            //                    make_tuple(I0, I0, k0, I0),
-            //                    a_thread_buf);
+            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                               make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I0, k0, I0),
+                               a_thread_buf);
             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
...

@@ -399,12 +399,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
         block_sync_lds();
         static_for<0, KRepeat, 1>{}([&](auto k0) {
-            // a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-            //                    make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
-            //                    a_block_buf,
-            //                    a_thread_desc_,
-            //                    make_tuple(I0, I0, k0, I0),
-            //                    a_thread_buf);
+            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                               make_tuple(I0, I0, I0, Number<k0 * AMmaKStride>{}),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I0, k0, I0),
+                               a_thread_buf);
             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                make_tuple(I0, I0, I0, Number<k0 * BMmaKStride>{}),
...
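Both hunks re-enable the A-side thread copy inside the compile-time KRepeat loop, mirroring the B-side copy that was already active. A stand-in for CK's static_for, to show the shape of that loop (illustrative, not CK's implementation):

    #include <cstdio>
    #include <utility>

    template <int Begin, int End, int Step, typename F>
    void static_for_sketch(F&& f)
    {
        if constexpr (Begin < End)
        {
            f(std::integral_constant<int, Begin>{});
            static_for_sketch<Begin + Step, End, Step>(std::forward<F>(f));
        }
    }

    int main()
    {
        constexpr int KRepeat = 4;
        static_for_sketch<0, KRepeat, 1>([](auto k0) {
            // In the kernel, both a_thread_copy_.Run(...) and b_thread_copy_.Run(...)
            // are issued here with a K offset of k0 * MmaKStride.
            std::printf("prefetch slice k0 = %d\n", k0.value);
        });
        return 0;
    }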
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

@@ -146,7 +146,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
     static constexpr auto BK1Number = Number<BK1Value>{};

     static constexpr index_t KPack =
-        math::max(math::lcm(AK1Number, BK1Number),
+        math::max(math::gcd(AK1Number, BK1Number),
                   MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
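Switching lcm to gcd can only shrink KPack, and only when AK1 != BK1; for the 8/8 configuration in the example above it is a no-op. A hedged illustration with plain integers (the k_per_blk value here is illustrative, not taken from MfmaSelector):

    #include <algorithm>
    #include <cstdio>
    #include <numeric>

    int main()
    {
        constexpr int AK1 = 8, BK1 = 2, k_per_blk = 4; // mixed K1s, illustrative
        constexpr int kpack_old = std::max(std::lcm(AK1, BK1), k_per_blk); // max(8, 4) = 8
        constexpr int kpack_new = std::max(std::gcd(AK1, BK1), k_per_blk); // max(2, 4) = 4
        std::printf("old KPack = %d, new KPack = %d\n", kpack_old, kpack_new);
        return 0;
    }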
...

@@ -1424,25 +1424,27 @@ struct GridwiseGemm_xdl_cshuffle_v3
         constexpr auto c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
             blockwise_gemm_pipeline.GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();

-        constexpr auto c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 =
-            blockwise_gemm_pipeline.GetCBlockDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4();
+        constexpr auto c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4 =
+            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_M1_N0_M2_M3_N1_N2_M4();

-        constexpr auto M0 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
-        constexpr auto M1 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
-        constexpr auto N0 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
-        constexpr auto M2 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
-        constexpr auto M3 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
-        constexpr auto N1 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
-        constexpr auto N2 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<8>{});
-        constexpr auto M4 =
-            c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<9>{});
+        constexpr auto M0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<0>{});
+        constexpr auto M1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<1>{});
+        constexpr auto N0 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<2>{});
+        constexpr auto M2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<3>{});
+        constexpr auto M3 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<4>{});
+        constexpr auto N1 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<5>{});
+        constexpr auto N2 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<6>{});
+        constexpr auto M4 = c_block_desc_m0_m1_n0_m2_m3_n1_n2_m4.GetLength(Number<7>{});
+
+        const auto c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4 = transform_tensor_descriptor(
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+            make_tuple(make_pass_through_transform(problem.MBlock),
+                       make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
+                       make_pass_through_transform(problem.NBlock),
+                       make_unmerge_transform(make_tuple(N0, N1, N2))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(
+                Sequence<0>{}, Sequence<2, 3, 5, 6, 9>{}, Sequence<1>{}, Sequence<4, 7, 8>{}));

         const auto c_thread_mtx_on_block =
             blockwise_gemm_pipeline.CalculateCThreadOriginDataIndexContiguous(I0, I0, I0, I0);
...
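The added transform unmerges dimension 1 of the grid view (MPerBlock) into (M0, M1, M2, M3, M4) and dimension 3 (NPerBlock) into (N0, N1, N2), which yields the output order (MBlock, NBlock, M0, M1, N0, M2, M3, N1, N2, M4). A scalar model of a row-major unmerge, with illustrative lengths:

    #include <cstdio>

    int main()
    {
        // MPerBlock = M0 * M1 * M2 * M3 * M4; trailing lengths are illustrative.
        constexpr int M1 = 2, M2 = 4, M3 = 1, M4 = 4;
        // Row-major unmerge: m = (((m0 * M1 + m1) * M2 + m2) * M3 + m3) * M4 + m4
        constexpr int m0 = 3, m1 = 1, m2 = 2, m3 = 0, m4 = 3;
        constexpr int m = (((m0 * M1 + m1) * M2 + m2) * M3 + m3) * M4 + m4;
        std::printf("(m0..m4) = (%d,%d,%d,%d,%d) -> m = %d within the block tile\n",
                    m0, m1, m2, m3, m4, m);
        return 0;
    }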
@@ -1474,7 +1476,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                 AccDataType,
                 CDataType,
                 decltype(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
-                decltype(c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
+                decltype(c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4),
                 CElementwiseOperation,
                 Sequence<I1, I1, M0, I1, I1, M2, I1, I1, N2, M4>,
                 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
...

@@ -1484,7 +1486,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
                 M4,
                 N2,
                 InMemoryDataOperationEnum::Set,
-                false>{c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
+                false>{c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                        make_multi_index(block_m_id,
                                         block_n_id,
                                         m_thread_data_on_block_idx[I0],
...

@@ -1500,7 +1502,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
             c_thread_copy_vgpr_to_global.Run(c_thread_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                                              c_thread_buf,
-                                             c_block_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
+                                             c_grid_desc_mblock_nblock_m0_m1_n0_m2_m3_n1_n2_m4,
                                              c_grid_buf);
         }
...
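With the destination descriptor switched from the block view to the grid view, the copy origin (block_m_id, block_n_id, thread offsets) now addresses this block's tile inside the whole C tensor, matching the c_grid_buf destination. A scalar sketch of that origin computation, assuming a row-major C with illustrative sizes (not CK's coordinate machinery):

    #include <cstdio>

    int main()
    {
        constexpr int MPerBlock = 256, NPerBlock = 256, N = 4096; // illustrative
        const int block_m_id = 2, block_n_id = 5;  // from the block-to-ctile map
        const int m_in_block = 17, n_in_block = 33; // thread origin inside the tile

        // Linear offset of the first element this thread writes in row-major C:
        const long long offset = (long long)(block_m_id * MPerBlock + m_in_block) * N +
                                 block_n_id * NPerBlock + n_in_block;
        std::printf("global C offset = %lld\n", offset);
        return 0;
    }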
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -399,10 +399,57 @@ struct ThreadwiseTensorSliceTransfer_v1r4
         constexpr auto dst_dim_access_order = DstDimAccessOrder{};

+        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
+
         constexpr auto ordered_dst_access_lengths =
-            container_reorder_given_new2old(access_lengths, dst_dim_access_order);
+            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
+
+        // make forward steps
+        const auto dst_forward_steps = generate_tuple(
+            [&](auto i) {
+                Index forward_step_idx;
+
+                static_for<0, nDim, 1>{}([&](auto j) {
+                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
+                });
+
+                return make_tensor_coordinate_step(dst_desc, forward_step_idx);
+            },
+            Number<nDim>{});
+
+        // make backward steps
+        const auto dst_backward_steps = generate_tuple(
+            [&](auto i) {
+                Index backward_step_idx;
+
+                static_for<0, nDim, 1>{}([&](auto j) {
+                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
+                });
+
+                return make_tensor_coordinate_step(dst_desc, backward_step_idx);
+            },
+            Number<nDim>{});

         static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
+            // judge move forward or move backward
+            constexpr auto forward_sweep = [&]() {
+                StaticallyIndexedArray<bool, nDim> forward_sweep_;
+
+                forward_sweep_(Number<0>{}) = true;
+
+                static_for<1, nDim, 1>{}([&](auto i) {
+                    index_t tmp = ordered_dst_access_idx[Number<0>{}];
+
+                    static_for<1, i, 1>{}([&](auto j) {
+                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
+                    });
+
+                    forward_sweep_(i) = tmp % 2 == 0;
+                });
+
+                return forward_sweep_;
+            }();
+
             using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
             using dst_vector_t   = typename dst_vector_type::type;
...
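forward_sweep encodes a snake ("boustrophedon") traversal: dimension i walks forward when the combined ordered index of the slower dimensions is even, and backward when it is odd, so each access needs only a single coordinate step. A host-side model of the parity rule for two dimensions:

    #include <cstdio>

    int main()
    {
        constexpr int L0 = 2, L1 = 4; // ordered access lengths (illustrative)
        for (int i0 = 0; i0 < L0; ++i0)
        {
            const bool fwd1 = (i0 % 2 == 0); // forward_sweep_ for dimension 1
            for (int s = 0; s < L1; ++s)
            {
                const int i1 = fwd1 ? s : L1 - 1 - s;
                std::printf("(%d, %d) ", i0, i1);
            }
        }
        std::printf("\n"); // prints (0,0)..(0,3) then (1,3)..(1,0): a snake
        return 0;
    }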
@@ -423,10 +470,39 @@ struct ThreadwiseTensorSliceTransfer_v1r4
                 is_dst_valid,
                 dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);

-            move_tensor_coordinate(
-                dst_desc,
-                dst_coord_,
-                make_tensor_coordinate_step(dst_desc, to_multi_index(data_to_origin_disp_idx)));
+            constexpr auto move_on_dim = [&]() constexpr
+            {
+                StaticallyIndexedArray<bool, nDim> move_on_dim_;
+
+                static_for<0, nDim, 1>{}([&](auto i) {
+                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
+
+                    static_for<i + 1, nDim, 1>{}([&](auto j) {
+                        move_on_dim_(i) &=
+                            ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
+                    });
+                });
+
+                return move_on_dim_;
+            }
+            ();
+
+            // move dst coord
+            static_for<0, nDim, 1>{}([&](auto i) {
+                if constexpr(move_on_dim[i])
+                {
+                    if constexpr(forward_sweep[i])
+                    {
+                        move_tensor_coordinate(
+                            dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
+                    }
+                    else
+                    {
+                        move_tensor_coordinate(
+                            dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
+                    }
+                }
+            });
         });

         // move dst coordinate back to slice origin (or not)
...
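move_on_dim is the matching odometer rule: dimension i advances only when it is not at its last index and every faster-varying dimension j > i already is. A host-side model with illustrative lengths:

    #include <cstdio>

    int main()
    {
        constexpr int nDim = 3;
        const int len[nDim] = {2, 2, 3};
        const int idx[nDim] = {0, 1, 2}; // current ordered access index

        for (int i = 0; i < nDim; ++i)
        {
            bool move = idx[i] < len[i] - 1;
            for (int j = i + 1; j < nDim; ++j)
                move = move && (idx[j] == len[j] - 1);
            std::printf("move_on_dim[%d] = %d\n", i, move); // here: 1, 0, 0
        }
        return 0;
    }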
@@ -1697,28 +1773,20 @@ struct ThreadwiseTensorSliceTransfer_v5
         constexpr auto src_scalar_per_access = generate_sequence(
             detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

         constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;

         constexpr auto src_dim_access_order = SrcDimAccessOrder{};

         constexpr auto ordered_access_lengths =
             container_reorder_given_new2old(access_lengths, src_dim_access_order);

         static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
             // position in slice window
             constexpr auto data_to_origin_disp_idx =
                 ordered_access_idx.ReorderGivenOld2New(src_dim_access_order) *
                 src_scalar_per_access;
-#if 0
-            if (get_thread_local_1d_id()==0){
-                printf("%d, %d, %d, %d\n",
-                data_to_origin_disp_idx.At(Number<0>{}).value,
-                data_to_origin_disp_idx.At(Number<1>{}).value,
-                data_to_origin_disp_idx.At(Number<2>{}).value,
-                data_to_origin_disp_idx.At(Number<3>{}).value);
-            }
-#endif
             // src coordinate
             constexpr auto src_ref_to_data_disp_idx =
                 src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
...

@@ -1740,16 +1808,9 @@ struct ThreadwiseTensorSliceTransfer_v5
             // copy data from src_buf into src_tmp_vector
             src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                 src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
-#if 0
-            if (get_thread_local_1d_id()<32){
-                printf("Tid: %02d, Index(%d, %d, %d, %d), offset: %d\n", get_thread_local_1d_id(), src_data_coord.GetIndex().At(Number<0>{}),
-                src_data_coord.GetIndex().At(Number<1>{}),
-                src_data_coord.GetIndex().At(Number<2>{}),
-                src_data_coord.GetIndex().At(Number<3>{}), src_data_coord.GetOffset());
-            }
-#endif
             // Set data to scratch
-            src_thread_scratch_.template SetAsType_Print<src_vector_t>(
+            src_thread_scratch_.template SetAsType<src_vector_t>(
                 data_to_origin_disp_idx, src_tmp_vector.template AsType<src_vector_t>()[I0]);
         });
...
@@ -1847,8 +1908,10 @@ struct ThreadwiseTensorSliceTransfer_v5
         constexpr auto dst_dim_access_order = DstDimAccessOrder{};

+        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
+
         constexpr auto ordered_dst_access_lengths =
-            container_reorder_given_new2old(access_lengths, dst_dim_access_order);
+            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

         static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
             // position in slice window
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -950,22 +950,18 @@ struct XdlopsGemm
                        Sequence<7>{}));
     }

-    template <typename CDesc_MBlock_NBlock_M0_N0_M1_N1_M2_N2>
-    __host__ __device__ static constexpr auto MakeCDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4(
-        const CDesc_MBlock_NBlock_M0_N0_M1_N1_M2_N2& c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2)
+    template <typename CDesc_M0_N0_M1_N1_M2_N2>
+    __host__ __device__ static constexpr auto
+    MakeCDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
     {
-        const auto MBlock = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I0);
-        const auto NBlock = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I1);
-        const auto M0 = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I2);
-        const auto N0 = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I3);
-        const auto M1 = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I4);
-        const auto N1 = c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2.GetLength(I5);
+        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
+        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
+        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
+        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);

         return transform_tensor_descriptor(
-            c_desc_mblock_nblock_m0_n0_m1_n1_m2_n2,
-            make_tuple(make_pass_through_transform(MBlock),
-                       make_pass_through_transform(NBlock),
-                       make_pass_through_transform(M0),
+            c_desc_m0_n0_m1_n1_m2_n2,
+            make_tuple(make_pass_through_transform(M0),
                        make_pass_through_transform(N0),
                        make_pass_through_transform(M1),
                        make_pass_through_transform(N1),
...

@@ -978,17 +974,13 @@ struct XdlopsGemm
                        Sequence<2>{},
                        Sequence<3>{},
                        Sequence<4>{},
-                       Sequence<5>{},
-                       Sequence<6>{},
-                       Sequence<7>{}),
-            make_tuple(Sequence<0>{},
-                       Sequence<1>{},
-                       Sequence<2>{},
-                       Sequence<4>{},
-                       Sequence<3>{},
-                       Sequence<7>{},
-                       Sequence<5, 6, 9>{},
-                       Sequence<8>{}));
+                       Sequence<5>{}),
+            make_tuple(Sequence<0>{},
+                       Sequence<2>{},
+                       Sequence<1>{},
+                       Sequence<5>{},
+                       Sequence<3, 4, 7>{},
+                       Sequence<6>{}));
     }

     // transposed XDL output supporting C' = B' * A'
...
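After the change, the warp-level maker goes from the 6-d (M0, N0, M1, N1, M2, N2) accumulator view straight to the 8-d (M0, M1, N0, M2, M3, N1, N2, M4) view: four pass-throughs plus one unmerge of the per-XDL M dimension, with the MBlock/NBlock handling moved up to the gridwise kernel. A compile-time check, with illustrative lengths, that the two views cover the same elements:

    #include <cstdio>

    int main()
    {
        // Illustrative per-wave lengths for a 16x16 XDL output tile:
        constexpr int M0 = 8, N0 = 8, M1 = 2, N1 = 2, m2 = 16, N2 = 16;
        constexpr int M2 = 4, M3 = 1, M4 = 4; // unmerge of m2: M2 * M3 * M4 == m2

        static_assert(M2 * M3 * M4 == m2, "unmerge must preserve the dimension");
        static_assert(M0 * N0 * M1 * N1 * m2 * N2 ==
                          M0 * M1 * N0 * M2 * M3 * N1 * N2 * M4,
                      "both views address the same number of elements");
        std::printf("ok\n");
        return 0;
    }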