gaoqiong / composable_kernel_ROCM / Commits

Commit 72a488cc, authored Oct 23, 2024 by aska-0096
parent ea41fc2f

    All layout sanity pass

Showing 6 changed files with 743 additions and 609 deletions:
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp  +67 −10
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp    +63 −70
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp    +148 −93
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp         +377 −419
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                           +34 −0
include/ck/utility/transpose_vectors.hpp                                       +54 −17
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp

@@ -154,12 +154,12 @@ struct BlockwiseGemmXdlops_pipeline_base
         return make_tuple(c_thread_m, c_thread_n);
     }
 
-    // Contiguous output tile
+    // N Contiguous output tile
     template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
-    __device__ static auto CalculateCThreadOriginDataIndexContiguous(
-        Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
+    __device__ static auto CalculateCThreadOriginDataIndexNContiguous(
+        Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
     {
         const auto wave_idx = GetWaveIdx();
@@ -186,6 +186,38 @@ struct BlockwiseGemmXdlops_pipeline_base
         return make_tuple(c_thread_m, c_thread_n);
     }
 
+    // MNContiguous output tile
+    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
+    __device__ static auto CalculateCThreadOriginDataIndexMNContiguous(
+        Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
+    {
+        const auto wave_idx = GetWaveIdx();
+
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+
+        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
+
+        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_unmerge_transform(make_tuple(MWaves, MPerXDL, MRepeat))),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0, 1, 2>{}));
+
+        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_unmerge_transform(make_tuple(NWaves, NPerXDL, NRepeat))),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0, 1, 2>{}));
+
+        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
+            make_tuple(waveId_m, blk_idx[I0], m0))[I0];
+        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
+            make_tuple(waveId_n, blk_idx[I1], n0))[I0];
+
+        return make_tuple(c_thread_m, c_thread_n);
+    }
+
     template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
     __device__ static auto
     CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
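Note on the new MNContiguous index calculation: make_single_stage_tensor_adaptor over an unmerge transform is plain mixed-radix index arithmetic. Below is a scalar model (illustrative only, not CK code) of the bottom-index computation, assuming lengths (Waves, PerXDL, Repeat) with the repeat dimension innermost:

#include <cstdio>

// merged coordinate for the top index (wave_id, blk_row, repeat)
int bottom_index(int wave_id, int blk_row, int repeat, int per_xdl, int n_repeat)
{
    return (wave_id * per_xdl + blk_row) * n_repeat + repeat;
}

int main()
{
    // e.g. MWaves = 2, MPerXDL = 32, MRepeat = 4: wave 1, lane row 5, repeat 3
    printf("%d\n", bottom_index(1, 5, 3, 32, 4)); // (1*32 + 5)*4 + 3 = 151
}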
@@ -246,19 +278,30 @@ struct BlockwiseGemmXdlops_pipeline_base
             make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
     }
 
-    // Contiguous output tile
+    // N-Contiguous output tile
     __host__ __device__ static constexpr auto GetCThreadDescriptor_MBlock_NBlock_M0_M1_N0_M2_M3_N1_N2_M4()
     {
         constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
 
         constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
-        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
         constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
-        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];
 
         return make_naive_tensor_descriptor_packed(
-            make_tuple(I1, I1, Number<MRepeat>{}, I1, I1, M0, M1, N, Number<NRepeat>{}, M2));
+            make_tuple(I1, I1, Number<MRepeat>{}, I1, I1, M0, I1, I1, Number<NRepeat>{}, M2));
     }
 
+    // MN-Contiguous output tile
+    __host__ __device__ static constexpr auto GetCThreadDescriptor_MBlock_NBlock_M0_N0_M1_M2_N1_M3_N2_M4()
+    {
+        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
+
+        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
+        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
+
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(I1, I1, I1, I1, M0, I1, I1, Number<MRepeat>{}, Number<NRepeat>{}, M2));
+    }
+
     __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
@@ -315,6 +358,20 @@ struct BlockwiseGemmXdlops_pipeline_base
         return xdlops_gemm.MakeCDescriptor_M0_M1_N0_M2_M3_N1_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
     }
 
+    // TransposeA
+    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_M2_N1_M3_N2_M4()
+    {
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MRepeat>{},
+                       Number<NRepeat>{},
+                       Number<MWaves>{},
+                       Number<NWaves>{},
+                       Number<MPerXDL>{},
+                       Number<NPerXDL>{}));
+
+        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_M2_N1_M3_N2_M4(c_block_desc_m0_n0_m1_n1_m2_n2);
+    }
+
     __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
     {
         constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
@@ -435,7 +492,7 @@ struct BlockwiseGemmXdlops_pipeline_base
...
@@ -435,7 +492,7 @@ struct BlockwiseGemmXdlops_pipeline_base
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
Sequence
<
NRepeat
,
1
,
1
,
KPack
>
,
Sequence
<
NRepeat
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
2
,
0
,
1
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
3
,
3
,
3
,
...
...
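Note on the Sequence<0, 1, 2, 3> → Sequence<2, 0, 1, 3> change above: the slice itself is unchanged; only the order in which the four dimensions are traversed changes, moving dim 2 outermost. A standalone sketch (illustrative only, not the CK transfer class) of what a dimension access order means:

#include <cstdio>

// visit a 4-D slice with a caller-chosen loop nesting
void visit(const int lens[4], const int order[4])
{
    int idx[4] = {0, 0, 0, 0};
    for(idx[order[0]] = 0; idx[order[0]] < lens[order[0]]; ++idx[order[0]])
        for(idx[order[1]] = 0; idx[order[1]] < lens[order[1]]; ++idx[order[1]])
            for(idx[order[2]] = 0; idx[order[2]] < lens[order[2]]; ++idx[order[2]])
                for(idx[order[3]] = 0; idx[order[3]] < lens[order[3]]; ++idx[order[3]])
                    printf("(%d,%d,%d,%d)\n", idx[0], idx[1], idx[2], idx[3]);
}

int main()
{
    const int lens[4]  = {2, 1, 1, 2};
    const int order[4] = {2, 0, 1, 3}; // dim 2 becomes the outermost loop
    visit(lens, order);
}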
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp

@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPacks>
+          index_t KPacks,
+          bool TransposeA,
+          bool TransposeB>
 struct BlockwiseGemmXdlops_pipeline_v4
 {
 };
@@ -55,7 +57,9 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack
+          index_t KPack,
+          bool TransposeA,
+          bool TransposeB
           // ,bool TransposeC //disable transposec right now...
           >
 struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
-                                       KPack>
+                                       KPack,
+                                       TransposeA,
+                                       TransposeB>
     : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                         NPerXDL,
                                         MRepeat,
                                         NRepeat,
-                                        KPack>
+                                        KPack,
+                                        TransposeA,
+                                        TransposeB>
 {
     using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                                    NPerXDL,
                                                    MRepeat,
                                                    NRepeat,
-                                                   KPack>;
+                                                   KPack,
+                                                   TransposeA,
+                                                   TransposeB>;
 
     using Base::I0;
     using Base::I1;
     using Base::KRepeat;
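The five hunks above all apply one mechanical change: two bool non-type template parameters are appended to the primary template, so every partial specialization and every base-class argument list must repeat them, or the specialization stops matching. A minimal sketch of the pattern (not CK code):

#include <cstdio>

template <int Version, bool TransposeA, bool TransposeB>
struct Pipeline; // primary template, intentionally undefined

template <bool TransposeA, bool TransposeB>
struct Pipeline<4, TransposeA, TransposeB> // specialization must carry both flags
{
    static void describe() { printf("v4, TransposeA=%d TransposeB=%d\n", TransposeA, TransposeB); }
};

int main() { Pipeline<4, true, false>::describe(); }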
@@ -298,22 +308,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
             // Local prefetch 1
             block_sync_lds();
             static_for<0, KRepeat, 1>{}([&](auto k) {
-                static_for<0, MRepeat, 1>{}([&](auto m0) {
-                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
-                                       a_block_buf.At(I0),
-                                       a_thread_desc_,
-                                       make_tuple(m0, I0, k, I0),
-                                       a_thread_bufs(I0));
-                    static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                           b_block_buf.At(I0),
-                                           b_thread_desc_,
-                                           make_tuple(n0, I0, k, I0),
-                                           b_thread_bufs(I0));
-                    });
-                });
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
+                                   a_block_buf.At(I0),
+                                   a_thread_desc_,
+                                   make_tuple(I0, I0, k, I0),
+                                   a_thread_bufs(I0));
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
+                                   b_block_buf.At(I0),
+                                   b_thread_desc_,
+                                   make_tuple(I0, I0, k, I0),
+                                   b_thread_bufs(I0));
             });
 
             // Global prefetch 3
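The restructuring above (repeated in the following hunks) replaces one small Run per (m0, n0) repeat with a single Run per k whose slice length already spans the repeat dimension; the transposing transfer introduced by this commit needs to see a whole repeat-wide tile at once. A minimal sketch of the equivalence, assuming a packed layout and a hypothetical copy_slice helper in place of ThreadwiseTensorSliceTransfer::Run:

#include <array>
#include <cstdio>

constexpr int MRepeat = 4, KPack = 8;

// stand-in for one thread-copy Run (hypothetical helper, not a CK API)
void copy_slice(const float* src, float* dst, int n)
{
    for(int i = 0; i < n; ++i)
        dst[i] = src[i];
}

// old shape of the loop: one Run per m0 repeat
void prefetch_per_repeat(const std::array<float, MRepeat * KPack>& lds,
                         std::array<float, MRepeat * KPack>& regs)
{
    for(int m0 = 0; m0 < MRepeat; ++m0)
        copy_slice(lds.data() + m0 * KPack, regs.data() + m0 * KPack, KPack);
}

// new shape: the slice spans all MRepeat repeats, so a single Run suffices
void prefetch_flat(const std::array<float, MRepeat * KPack>& lds,
                   std::array<float, MRepeat * KPack>& regs)
{
    copy_slice(lds.data(), regs.data(), MRepeat * KPack);
}

int main()
{
    std::array<float, MRepeat * KPack> lds{}, a{}, b{};
    lds[13] = 1.0f;
    prefetch_per_repeat(lds, a);
    prefetch_flat(lds, b);
    printf("%d\n", a == b); // 1: both orderings move the same data
}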
@@ -349,23 +355,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                 block_sync_lds();
                 static_for<0, KRepeat, 1>{}([&](auto k) {
-                    static_for<0, MRepeat, 1>{}([&](auto m0) {
-                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
-                                           a_block_buf.At(lds_read_buf),
-                                           a_thread_desc_,
-                                           make_tuple(m0, I0, k, I0),
-                                           a_thread_bufs(lds_read_reg_buf));
-                        static_for<0, NRepeat, 1>{}([&](auto n0) {
-                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                               b_block_buf.At(lds_read_buf),
-                                               b_thread_desc_,
-                                               make_tuple(n0, I0, k, I0),
-                                               b_thread_bufs(lds_read_reg_buf));
-                        });
-                    });
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
+                                       a_block_buf.At(lds_read_buf),
+                                       a_thread_desc_,
+                                       make_tuple(I0, I0, k, I0),
+                                       a_thread_bufs(lds_read_reg_buf));
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
+                                       b_block_buf.At(lds_read_buf),
+                                       b_thread_desc_,
+                                       make_tuple(I0, I0, k, I0),
+                                       b_thread_bufs(lds_read_reg_buf));
                 });
 
                 a_blockwise_copy.RunWrite(
@@ -430,22 +431,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                 block_sync_lds();
                 static_for<0, KRepeat, 1>{}([&](auto k) {
-                    static_for<0, MRepeat, 1>{}([&](auto m0) {
-                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
-                                           a_block_buf.At(lds_read_buf),
-                                           a_thread_desc_,
-                                           make_tuple(m0, I0, k, I0),
-                                           a_thread_bufs(lds_read_reg_buf));
-                        static_for<0, NRepeat, 1>{}([&](auto n0) {
-                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                               b_block_buf.At(lds_read_buf),
-                                               b_thread_desc_,
-                                               make_tuple(n0, I0, k, I0),
-                                               b_thread_bufs(lds_read_reg_buf));
-                        });
-                    });
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
+                                       a_block_buf.At(lds_read_buf),
+                                       a_thread_desc_,
+                                       make_tuple(I0, I0, k, I0),
+                                       a_thread_bufs(lds_read_reg_buf));
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
+                                       b_block_buf.At(lds_read_buf),
+                                       b_thread_desc_,
+                                       make_tuple(I0, I0, k, I0),
+                                       b_thread_bufs(lds_read_reg_buf));
                 });
 
                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
@@ -489,22 +486,18 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
             block_sync_lds();
             static_for<0, KRepeat, 1>{}([&](auto k) {
-                static_for<0, MRepeat, 1>{}([&](auto m0) {
-                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
-                                       a_block_buf.At(lds_read_buf),
-                                       a_thread_desc_,
-                                       make_tuple(m0, I0, k, I0),
-                                       a_thread_bufs(lds_read_reg_buf));
-                    static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                           b_block_buf.At(lds_read_buf),
-                                           b_thread_desc_,
-                                           make_tuple(n0, I0, k, I0),
-                                           b_thread_bufs(lds_read_reg_buf));
-                    });
-                });
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(I0, I0, I0, Number<k * AMmaKStride>{}),
+                                   a_block_buf.At(lds_read_buf),
+                                   a_thread_desc_,
+                                   make_tuple(I0, I0, k, I0),
+                                   a_thread_bufs(lds_read_reg_buf));
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(I0, I0, I0, Number<k * BMmaKStride>{}),
+                                   b_block_buf.At(lds_read_buf),
+                                   b_thread_desc_,
+                                   make_tuple(I0, I0, k, I0),
+                                   b_thread_bufs(lds_read_reg_buf));
             });
 
             static_for<0, KRepeat, 1>{}([&](auto k0) {
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp

@@ -32,7 +32,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPacks>
+          index_t KPacks,
+          bool TransposeA,
+          bool TransposeB>
 struct BlockwiseGemmXdlops_pipeline_v5
 {
 };
@@ -55,7 +57,9 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack
+          index_t KPack,
+          bool TransposeA,
+          bool TransposeB
           // ,bool TransposeC //disable transposec right now...
           >
 struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
@@ -77,7 +81,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
-                                       KPack>
+                                       KPack,
+                                       TransposeA,
+                                       TransposeB>
     : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
@@ -96,7 +102,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                         NPerXDL,
                                         MRepeat,
                                         NRepeat,
-                                        KPack>
+                                        KPack,
+                                        TransposeA,
+                                        TransposeB>
 {
     using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -117,7 +125,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                                    NPerXDL,
                                                    MRepeat,
                                                    NRepeat,
-                                                   KPack>;
+                                                   KPack,
+                                                   TransposeA,
+                                                   TransposeB>;
 
     using Base::A_K1;
     using Base::B_K1;
     using Base::I0;
@@ -381,22 +391,19 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
             // Local prefetch 1
             block_sync_lds();
-            static_for<0, MRepeat, 1>{}([&](auto m0) {
-                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                   make_tuple(m0, I0, I0, I0),
-                                   a_block_buf,
-                                   a_thread_desc_,
-                                   make_tuple(m0, I0, I0, I0),
-                                   a_thread_buf);
-            });
-            static_for<0, NRepeat, 1>{}([&](auto n0) {
-                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                   make_tuple(n0, I0, I0, I0),
-                                   b_block_buf,
-                                   b_thread_desc_,
-                                   make_tuple(n0, I0, I0, I0),
-                                   b_thread_buf);
-            });
+            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                               make_tuple(I0, I0, I0, I0),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I0, I0, I0),
+                               a_thread_buf);
+            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                               make_tuple(I0, I0, I0, I0),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(I0, I0, I0, I0),
+                               b_thread_buf);
 
             // main body
             if constexpr(HasMainLoop)
@@ -449,25 +456,23 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                 b_thread_vec.template AsType<mfma_input_type>(),
                                 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                         });
-
-                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                           make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
-                                           a_block_buf,
-                                           a_thread_desc_,
-                                           make_tuple(m0, I0, I0, I0),
-                                           a_thread_buf);
                     });
 
-                    static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                           make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
-                                           b_block_buf,
-                                           b_thread_desc_,
-                                           make_tuple(n0, I0, I0, I0),
-                                           b_thread_buf);
-                    });
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
+                                       a_block_buf,
+                                       a_thread_desc_,
+                                       make_tuple(I0, I0, I0, I0),
+                                       a_thread_buf);
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
+                                       b_block_buf,
+                                       b_thread_desc_,
+                                       make_tuple(I0, I0, I0, I0),
+                                       b_thread_buf);
                 });
 
                 HotLoopScheduler();
@@ -517,24 +522,23 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                 b_thread_vec.template AsType<mfma_input_type>(),
                                 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                         });
-
-                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                           make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
-                                           a_block_buf,
-                                           a_thread_desc_,
-                                           make_tuple(m0, I0, I0, I0),
-                                           a_thread_buf);
                     });
 
-                    static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                           make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
-                                           b_block_buf,
-                                           b_thread_desc_,
-                                           make_tuple(n0, I0, I0, I0),
-                                           b_thread_buf);
-                    });
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
+                                       a_block_buf,
+                                       a_thread_desc_,
+                                       make_tuple(I0, I0, I0, I0),
+                                       a_thread_buf);
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
+                                       b_block_buf,
+                                       b_thread_desc_,
+                                       make_tuple(I0, I0, I0, I0),
+                                       b_thread_buf);
                 });
 
                 HotLoopScheduler();
@@ -567,25 +571,23 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                 b_thread_vec.template AsType<mfma_input_type>(),
                                 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                         });
-
-                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                           make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
-                                           a_block_buf,
-                                           a_thread_desc_,
-                                           make_tuple(m0, I0, I0, I0),
-                                           a_thread_buf);
                     });
 
-                    static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                           make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
-                                           b_block_buf,
-                                           b_thread_desc_,
-                                           make_tuple(n0, I0, I0, I0),
-                                           b_thread_buf);
-                    });
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
+                                       a_block_buf,
+                                       a_thread_desc_,
+                                       make_tuple(I0, I0, I0, I0),
+                                       a_thread_buf);
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(I0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
+                                       b_block_buf,
+                                       b_thread_desc_,
+                                       make_tuple(I0, I0, I0, I0),
+                                       b_thread_buf);
                 });
 
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -636,28 +638,81 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
     static constexpr auto b_thread_desc_ =
         make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, I1, I1, Number<KPack>{}));
 
-    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
-                                                         ComputeDataType,
-                                                         decltype(a_block_desc_m0_m1_m2_k),
-                                                         decltype(a_thread_desc_),
-                                                         Sequence<1, 1, 1, KPack>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         A_K1,
-                                                         A_K1>;
-
-    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
-                                                         ComputeDataType,
-                                                         decltype(b_block_desc_n0_n1_n2_k),
-                                                         decltype(b_thread_desc_),
-                                                         Sequence<1, 1, 1, KPack>,
-                                                         Sequence<0, 1, 2, 3>,
-                                                         3,
-                                                         B_K1,
-                                                         B_K1>;
-
-    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
+    template <bool Transpose>
+    struct AThreadCopySelector;
+
+    template <>
+    struct AThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
+                                                      ComputeDataType,
+                                                      decltype(a_block_desc_m0_m1_m2_k),
+                                                      decltype(a_thread_desc_),
+                                                      Sequence<MRepeat, 1, 1, KPack>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      3,
+                                                      3,
+                                                      A_K1,
+                                                      A_K1>;
+    };
+
+    template <>
+    struct AThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<ADataType,
+                                                      ComputeDataType,
+                                                      decltype(a_block_desc_m0_m1_m2_k),
+                                                      decltype(a_thread_desc_),
+                                                      Sequence<MRepeat, 1, 1, KPack>,
+                                                      Sequence<3, 1, 2, 0>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      0,
+                                                      3,
+                                                      MRepeat,
+                                                      A_K1>;
+    };
+
+    template <bool Transpose>
+    struct BThreadCopySelector;
+
+    template <>
+    struct BThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
+                                                      ComputeDataType,
+                                                      decltype(b_block_desc_n0_n1_n2_k),
+                                                      decltype(b_thread_desc_),
+                                                      Sequence<NRepeat, 1, 1, KPack>,
+                                                      Sequence<2, 0, 1, 3>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      3,
+                                                      3,
+                                                      B_K1,
+                                                      B_K1>;
+    };
+
+    template <>
+    struct BThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v5<BDataType,
+                                                      ComputeDataType,
+                                                      decltype(b_block_desc_n0_n1_n2_k),
+                                                      decltype(b_thread_desc_),
+                                                      Sequence<NRepeat, 1, 1, KPack>,
+                                                      Sequence<3, 1, 2, 0>,
+                                                      Sequence<0, 1, 2, 3>,
+                                                      0,
+                                                      3,
+                                                      NRepeat,
+                                                      B_K1>;
+    };
+
+    typename AThreadCopySelector<TransposeA>::type a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
+    typename BThreadCopySelector<TransposeB>::type b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
 
     using Base::c_thread_desc_;
 };
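The AThreadCopySelector/BThreadCopySelector additions use explicit template specialization on a bool to pick the thread-copy type at compile time, so `typename Selector<Flag>::type` names exactly one variant and the unused one is never instantiated. A minimal standalone sketch of the idiom (not CK code):

#include <cstdio>

struct DirectCopy    { static constexpr const char* name = "direct";    };
struct TransposeCopy { static constexpr const char* name = "transpose"; };

template <bool Transpose>
struct CopySelector;

template <>
struct CopySelector<false> { using type = DirectCopy; };

template <>
struct CopySelector<true>  { using type = TransposeCopy; };

template <bool TransposeA>
struct PipelineMember
{
    typename CopySelector<TransposeA>::type copy_; // resolved at compile time
};

int main()
{
    printf("%s\n", decltype(PipelineMember<true>{}.copy_)::name); // "transpose"
}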
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

(This diff is collapsed: +377, −419.)
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -983,6 +983,40 @@ struct XdlopsGemm
                        Sequence<5>{}));
     }
 
+    // TransposeA
+    template <typename CDesc_M0_N0_M1_N1_M2_N2>
+    __host__ __device__ static constexpr auto
+    MakeCDescriptor_M0_N0_M1_M2_N1_M3_N2_M4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
+    {
+        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
+        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
+        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
+        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
+
+        return transform_tensor_descriptor(
+            c_desc_m0_n0_m1_n1_m2_n2,
+            make_tuple(make_pass_through_transform(M0),
+                       make_pass_through_transform(N0),
+                       make_pass_through_transform(M1),
+                       make_pass_through_transform(N1),
+                       make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
+                                                         Number<mfma_instr.num_input_blks>{},
+                                                         Number<mfma_instr.group_size>{})),
+                       make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5>{}),
+            make_tuple(Sequence<7>{},
+                       Sequence<6>{},
+                       Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2, 3, 5>{},
+                       Sequence<4>{}));
+    }
+
     // transposed XDL output supporting C' = B' * A'
     // M2_N2 -> M2_N2_N3_N4
     template <typename CDesc_M0_N0_M1_N1_M2_N2>
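Note on the new MakeCDescriptor_M0_N0_M1_M2_N1_M3_N2_M4: the interesting piece is the unmerge of the per-block row coordinate into (group, input block, element-in-group). A scalar model of that decomposition (illustrative only; the shape numbers in the example are assumptions, not taken from this diff):

#include <cstdio>

// decompose m2 with lengths (num_groups_per_blk, num_input_blks, group_size),
// the last dimension being the fastest-moving one
void unmerge_m2(int m2, int num_input_blks, int group_size)
{
    const int in_group  = m2 % group_size;
    const int input_blk = (m2 / group_size) % num_input_blks;
    const int group     = m2 / (group_size * num_input_blks);
    printf("m2=%d -> (group=%d, input_blk=%d, in_group=%d)\n", m2, group, input_blk, in_group);
}

int main()
{
    // e.g. a 32-row block with num_input_blks = 2 and group_size = 4:
    unmerge_m2(21, 2, 4); // -> (group=2, input_blk=1, in_group=1), since 2*8 + 1*4 + 1 = 21
}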
include/ck/utility/transpose_vectors.hpp

@@ -18,22 +18,6 @@ struct transpose_vectors;
 // transpose fp16 2x2
 __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1)
 {
-#if 0
-    static constexpr auto I0 = Number<0>{};
-    static constexpr auto I1 = Number<1>{};
-
-    const vector_type<half_t, 2> vx0{x0}, vx1{x1};
-    vector_type<half_t, 2> vy0, vy1;
-
-    vy0.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I0];
-    vy0.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I0];
-
-    vy1.template AsType<half_t>()(I0) = vx0.template AsType<half_t>()[I1];
-    vy1.template AsType<half_t>()(I1) = vx1.template AsType<half_t>()[I1];
-
-    y0 = vy0.template AsType<half2_t>()[I0];
-    y1 = vy1.template AsType<half2_t>()[I0];
-#else
     constexpr int32_t m0 = 0x05040100;
     constexpr int32_t m1 = 0x07060302;
@@ -43,7 +27,6 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t
     // index is reversed because of little endianness (least significant bits first)
     y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
     y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
-#endif
 }
 
 template <index_t NX, index_t NY>
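The two v_perm selectors 0x05040100 and 0x07060302 pick, byte by byte, the low and then the high 16-bit halves of the two source words, which is exactly a 2x2 transpose of 16-bit lanes. A portable model of the builtin for selector bytes 0-7 (an approximation: the hardware instruction also defines selector bytes above 7, which this sketch ignores):

#include <cstdint>
#include <cstdio>

// model of __builtin_amdgcn_perm(hi, lo, sel): selector byte i picks byte
// sel[i] out of the 64-bit value {hi:lo}, byte 0 being the LSB of lo
uint32_t perm_b32(uint32_t hi, uint32_t lo, uint32_t sel)
{
    const uint64_t src = (uint64_t(hi) << 32) | lo;
    uint32_t out = 0;
    for(int i = 0; i < 4; ++i)
    {
        const uint32_t s    = (sel >> (8 * i)) & 0xff; // selector byte i
        const uint32_t byte = uint32_t(src >> (8 * s)) & 0xff;
        out |= byte << (8 * i);
    }
    return out;
}

int main()
{
    const uint32_t x0 = 0xA1A1A0A0u; // 16-bit lanes: a0 = 0xA0A0 (low), a1 = 0xA1A1 (high)
    const uint32_t x1 = 0xB1B1B0B0u; // 16-bit lanes: b0 = 0xB0B0,       b1 = 0xB1B1
    printf("y0 = %08X\n", perm_b32(x1, x0, 0x05040100)); // B0B0A0A0, i.e. [a0 b0]
    printf("y1 = %08X\n", perm_b32(x1, x0, 0x07060302)); // B1B1A1A1, i.e. [a1 b1]
}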
@@ -83,6 +66,60 @@ struct transpose_vectors<half_t, NX, NY>
     }
 };
 
+// transpose bf16 2x2
+__device__ void transpose_bf16_2x2(const bhalf2_t& x0, const bhalf2_t& x1, bhalf2_t& y0, bhalf2_t& y1)
+{
+    constexpr int32_t m0 = 0x05040100;
+    constexpr int32_t m1 = 0x07060302;
+
+    // ex: v_perm_b32(0x11223344, 0x55667788, 0x05010400) -> 0x33774488
+    //     byte index 7..4 in the first operand, 3..0 in the second
+    // index is reversed because of little endianness (least significant bits first)
+    y0 = bit_cast<bhalf2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
+    y1 = bit_cast<bhalf2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
+}
+
+template <index_t NX, index_t NY>
+struct transpose_vectors<bhalf_t, NX, NY>
+{
+    // we got [NY * NX] amount of S data to be transposed
+    static constexpr index_t s_per_x = NY;
+    static constexpr index_t s_per_y = NX;
+
+    using S  = bhalf_t;
+    using VX = vector_type<bhalf_t, s_per_x>;
+    using VY = vector_type<bhalf_t, s_per_y>;
+
+    __device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
+                               StaticallyIndexedArray<VY&, NY>& vy_tuple)
+    {
+        static constexpr auto I1 = Number<1>{};
+        static constexpr auto I2 = Number<2>{};
+
+        static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
+
+        // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
+        static_for<0, NY, 2>{}([&](auto iy) {
+            static_for<0, NX, 2>{}([&](auto ix) {
+                // reference to 2 bhalf2_t data from vx_tuple
+                const auto& x_s2_0 = vx_tuple[ix].template AsType<bhalf2_t>()[iy / I2];
+                const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<bhalf2_t>()[iy / I2];
+
+                // reference to 2 bhalf2_t data from vy_tuple
+                auto& y_s2_0 = vy_tuple(iy).template AsType<bhalf2_t>()(ix / I2);
+                auto& y_s2_1 = vy_tuple(iy + I1).template AsType<bhalf2_t>()(ix / I2);
+
+                // transpose
+                transpose_bf16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
+            });
+        });
+    }
+};
+
 // transpose int8 4x4
 __device__ void transpose_int8_4x4(const int8x4_t& x0,
                                    const int8x4_t& x1,
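The new transpose_vectors<bhalf_t, NX, NY> walks the NY x NX element matrix in 2x2 blocks and hands each block to transpose_bf16_2x2, which is why NX and NY must both be even. A scalar sketch of the same tiling (illustrative only, not CK code):

#include <cstdio>

template <int NY, int NX>
void transpose_by_2x2(const unsigned short (&x)[NY][NX], unsigned short (&y)[NX][NY])
{
    static_assert(NX % 2 == 0 && NY % 2 == 0, "wrong!");
    for(int iy = 0; iy < NY; iy += 2)
        for(int ix = 0; ix < NX; ix += 2)
            for(int dy = 0; dy < 2; ++dy)     // 2x2 micro-transpose; on the GPU this
                for(int dx = 0; dx < 2; ++dx) // is a single transpose_bf16_2x2 call
                    y[ix + dx][iy + dy] = x[iy + dy][ix + dx];
}

int main()
{
    const unsigned short x[2][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}};
    unsigned short y[4][2];
    transpose_by_2x2(x, y);
    for(auto& row : y)
        printf("%d %d\n", row[0], row[1]); // prints the 4x2 transpose
}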