gaoqiong / composable_kernel · Commits

Commit d3341a67, authored Aug 16, 2021 by Jing Zhang

    xdlops refactor

Parent: b62bf8c3

Changes: 5 changed files with 354 additions and 617 deletions (+354 −617)
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp      +186 −392
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp   +84 −182
composable_kernel/include/tensor_operation/xdlops_gemm.hpp                 +74 −35
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp                     +9 −7
host/driver_offline/src/conv_fwd_driver_offline.cpp                         +1 −1
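Throughout the diff the C output descriptor changes from (M0, M1, M2, N) to (M0, N0, M1, N1, M2, M3, M4, N2): M is unmerged into (MRepeat, MWaves, M2, M3, M4) and N into (NRepeat, NWaves, N2). A rough standalone sketch of that unmerge follows; the split lengths are invented for illustration and are not values from this commit.

    #include <cstdio>

    int main()
    {
        // Illustrative split lengths (assumptions, not values from the commit).
        constexpr int MWaves = 2, M2 = 4, M3 = 1, M4 = 4; // M -> (m0, m1, m2, m3, m4)
        constexpr int NWaves = 2, N2 = 32;                // N -> (n0, n1, n2)

        const int m = 137, n = 75;

        const int m4 = m % M4;
        const int m3 = (m / M4) % M3;
        const int m2 = (m / (M4 * M3)) % M2;
        const int m1 = (m / (M4 * M3 * M2)) % MWaves;
        const int m0 = m / (M4 * M3 * M2 * MWaves); // MRepeat index

        const int n2 = n % N2;
        const int n1 = (n / N2) % NWaves;
        const int n0 = n / (N2 * NWaves);           // NRepeat index

        std::printf("(m, n)=(%d, %d) -> (m0 n0 m1 n1 m2 m3 m4 n2) = (%d %d %d %d %d %d %d %d)\n",
                    m, n, m0, n0, m1, n1, m2, m3, m4, n2);
    }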
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
@@ -9,82 +9,99 @@ namespace ck {
 template <index_t BlockSize,
           typename FloatAB,
-          class ABlockDesc,
-          class BBlockDesc,
-          index_t MPerWave,
-          index_t NPerWave,
+          typename AK0MK1BlockDesc,
+          typename BK0NK1BlockDesc,
+          index_t MPerXDL,
+          index_t NPerXDL,
+          index_t MRepeat,
+          index_t NRepeat,
           index_t K1>
-struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
+struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
 {
     using CIndex = MultiIndex<2>;

     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

     static constexpr index_t WaveSize = 64;

-    static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
-    static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
-    static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
-    static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
+    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
+    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
+    static constexpr index_t K0        = BK0NK1BlockDesc{}.GetLength(I0);
+    static constexpr index_t KPerBlock = K0;
+    static constexpr index_t KPack     = K1;

-    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
+    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
+    static constexpr auto CXdlopsLayout = xdlops_gemm.GetCXdlopsLayout();

-    static constexpr index_t MWaves = M1 / MPerWave;
-    static constexpr index_t NWaves = N1 / NPerWave;
-    static constexpr index_t MRepeat = M0;
-    static constexpr index_t NRepeat = N0;
+    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
+    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

+    __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDesc()
+    {
+        constexpr auto M0 = Number<CXdlopsLayout.M1()>{};
+        constexpr auto M2 = Number<CXdlopsLayout.M0()>{};
+
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(I1, I1, I1, I1, M0, I1, M2, I1));
+    }

-    __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
-    __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
-    __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }

+    __device__ static auto GetWaveIdx()
+    {
+        const index_t thread_id = get_thread_local_1d_id();
+
+        const auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));
+
+        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
+    }

     __device__ static auto CalculateAThreadOriginDataIndex()
     {
-        const index_t thread_id = get_thread_local_1d_id();
-        const index_t waveId    = thread_id / WaveSize;
-        const index_t laneId    = thread_id % WaveSize;
-        const index_t waveId_m  = waveId / NWaves;
+        const auto wave_idx = GetWaveIdx();
+        const auto waveId_m = wave_idx[I0];
+        const auto laneId   = wave_idx[I2];
+
+        const auto blk_idx = xdlops_gemm.GetBlkIdx();
+        const auto blk_id  = blk_idx[I0];
+        const auto blk_td  = blk_idx[I1];

         if constexpr(xdlops_gemm.IsKReduction)
         {
-            const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
-            const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
-            return make_tuple(k_offset, 0, m_offset, 0);
+            return make_tuple(blk_id, 0, waveId_m, blk_td, 0);
         }
         else
         {
-            const index_t m_offset = waveId_m * MPerWave + laneId;
-            const index_t k_offset = 0;
-            return make_tuple(k_offset, 0, m_offset, 0);
+            return make_tuple(0, 0, waveId_m, laneId, 0);
         }
     }

     __device__ static auto CalculateBThreadOriginDataIndex()
     {
-        const index_t thread_id = get_thread_local_1d_id();
-        const index_t waveId    = thread_id / WaveSize;
-        const index_t laneId    = thread_id % WaveSize;
-        const index_t waveId_n  = waveId % NWaves;
+        const auto wave_idx = GetWaveIdx();
+        const auto waveId_n = wave_idx[I1];
+        const auto laneId   = wave_idx[I2];
+
+        const auto blk_idx = xdlops_gemm.GetBlkIdx();
+        const auto blk_id  = blk_idx[I0];
+        const auto blk_td  = blk_idx[I1];

         if constexpr(xdlops_gemm.IsKReduction)
         {
-            const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
-            const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
-            return make_tuple(k_offset, 0, n_offset, 0);
+            return make_tuple(blk_id, 0, waveId_n, blk_td, 0);
         }
         else
         {
-            const index_t n_offset = waveId_n * NPerWave + laneId;
-            const index_t k_offset = 0;
-            return make_tuple(k_offset, 0, n_offset, 0);
+            return make_tuple(0, 0, waveId_n, laneId, 0);
         }
     }
...
...
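For reference, the new GetWaveIdx replaces explicit div/mod arithmetic on the thread id with a merge-transform adaptor over (MWaves, NWaves, WaveSize). A minimal plain C++ model of the same decomposition, with illustrative sizes that are not taken from the commit:

    #include <cstdio>

    struct WaveIdx { int waveId_m, waveId_n, laneId; };

    constexpr WaveIdx get_wave_idx(int thread_id, int NWaves, int WaveSize)
    {
        const int laneId   = thread_id % WaveSize;
        const int waveId   = thread_id / WaveSize;
        const int waveId_n = waveId % NWaves;
        const int waveId_m = waveId / NWaves;
        return {waveId_m, waveId_n, laneId};
    }

    int main()
    {
        constexpr int MWaves = 2, NWaves = 2, WaveSize = 64; // illustrative
        const WaveIdx w = get_wave_idx(130, NWaves, WaveSize);
        std::printf("waveId_m=%d waveId_n=%d laneId=%d (MWaves=%d)\n",
                    w.waveId_m, w.waveId_n, w.laneId, MWaves);
    }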
@@ -92,263 +109,117 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
     __device__ static CIndex CalculateCThreadOriginDataIndex(
         Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
     {
-        const index_t waveId = get_thread_local_1d_id() / WaveSize;
-
-        const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
-
-        const index_t waveId_m = waveId / NWaves;
-        const index_t waveId_n = waveId % NWaves;
-
-        const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
-        const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
-
-        return CIndex{m_offset, n_offset};
+        const auto wave_idx = GetWaveIdx();
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+
+        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
+
+        constexpr auto mrepeat_mwave_mperxdl_to_m = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}));
+        constexpr auto nrepeat_nwave_nperxdl_to_n = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}));
+
+        const index_t c_thread_m =
+            mrepeat_mwave_mperxdl_to_m.CalculateOffset(make_tuple(m0, waveId_m, blk_idx[I0]));
+        const index_t c_thread_n =
+            nrepeat_nwave_nperxdl_to_n.CalculateOffset(make_tuple(n0, waveId_n, blk_idx[I1]));
+
+        return CIndex{c_thread_m, c_thread_n};
     }

-    __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1()
-        : a_thread_copy_{CalculateAThreadOriginDataIndex()},
-          b_thread_copy_{CalculateBThreadOriginDataIndex()}
+    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
     {
-        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
+        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
+                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

-        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
-                      "wrong! K dimension not consistent");
+        static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
+                      "wrong! K0 dimension not consistent");

-        static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
-                      "wrong! K1 dimension not consistent");
+        static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2),
+                      "wrong! K1 dimension not consistent");

         static_assert(BlockSize == MWaves * NWaves * WaveSize,
                       "BlockSize != MWaves * NWaves * WaveSize\n");

-        static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
-
-        constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
-
         static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
         static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");

+        static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 &&
+                          NPerBlock % (NPerXDL * NRepeat) == 0,
+                      "wrong!");
     }

     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
                         const BBlockBuffer& b_block_buf,
                         CThreadBuffer& c_thread_buf) const
     {
         auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
             a_thread_desc_.GetElementSpaceSize());
         auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
             b_thread_desc_.GetElementSpaceSize());

-        constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
+        vector_type<FloatAB, a_thread_desc_.GetElementSpaceSize()> a_thread_vec;
+        vector_type<FloatAB, b_thread_desc_.GetElementSpaceSize()> b_thread_vec;
+
+        constexpr index_t NumBlks   = CXdlopsLayout.GetNumBlks();
+        constexpr index_t NumXdlops = CXdlopsLayout.GetNumXdlops();
+
+        static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");

         static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
             // read A
             a_thread_copy_.Run(ABlockDesc{},
                                make_tuple(k, I0, I0, I0),
                                a_block_buf,
                                a_thread_desc_,
                                make_tuple(I0, I0, I0, I0),
                                a_thread_buf);

             // read B
             b_thread_copy_.Run(BBlockDesc{},
                                make_tuple(k, I0, I0, I0),
                                b_block_buf,
                                b_thread_desc_,
                                make_tuple(I0, I0, I0, I0),
                                b_thread_buf);

             using mfma_input_type =
                 typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_base>::type;

             static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
                 a_thread_vec.template AsType<FloatAB>()(Number<i>{}) = a_thread_buf[Number<i>{}];
             });

             static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
                 b_thread_vec.template AsType<FloatAB>()(Number<i>{}) = b_thread_buf[Number<i>{}];
             });

             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 static_for<0, NRepeat, 1>{}([&](auto n0) {
                     xdlops_gemm.template Run<decltype(a_thread_desc_),
                                              decltype(b_thread_desc_),
                                              decltype(c_thread_desc_),
                                              m0,
                                              n0>(a_thread_vec.template AsType<mfma_input_type>(),
                                                  b_thread_vec.template AsType<mfma_input_type>(),
                                                  c_thread_buf);
                 });
             });
         });
     }

     private:
     // A[K, M]
     static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));

     // B[K, N]
     static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));

     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

     using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                          FloatAB,
                                                          ABlockDesc,
                                                          decltype(a_thread_desc_),
                                                          Sequence<1, MRepeat, 1, K1>,
                                                          Sequence<0, 1, 2, 3>,
                                                          3,
                                                          K1,
                                                          1>;

     using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                          FloatAB,
                                                          BBlockDesc,
                                                          decltype(b_thread_desc_),
                                                          Sequence<1, NRepeat, 1, K1>,
                                                          Sequence<0, 1, 2, 3>,
                                                          3,
                                                          K1,
                                                          1>;

     AThreadCopy a_thread_copy_;
     BThreadCopy b_thread_copy_;
 };
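The reworked Run packs the K1 scalars held per thread into vectors of mfma_type.k_base elements before each MFMA call. A self-contained sketch of that grouping, using made-up sizes and an ordinary dot product in place of the intrinsic:

    #include <array>
    #include <cstdio>

    template <int K1, int KBase>
    void run_mfma_groups(const std::array<float, K1>& a, const std::array<float, K1>& b)
    {
        static_assert(K1 % KBase == 0, "K1 must be a multiple of k_base");

        for(int g = 0; g < K1 / KBase; ++g)
        {
            // In the real kernel this group would be handed to one mfma intrinsic.
            float acc = 0.f;
            for(int j = 0; j < KBase; ++j)
                acc += a[g * KBase + j] * b[g * KBase + j];
            std::printf("group %d partial dot = %f\n", g, acc);
        }
    }

    int main()
    {
        run_mfma_groups<8, 4>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f},
                              {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
    }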
-template <index_t BlockSize,
-          typename FloatAB,
-          class ABlockDesc,
-          class BBlockDesc,
-          index_t MPerWave,
-          index_t NPerWave,
-          index_t K1>
-struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
-{
-    using CIndex = MultiIndex<2>;
-
-    static constexpr auto I0 = Number<0>{};
-    static constexpr auto I1 = Number<1>{};
-    static constexpr auto I2 = Number<2>{};
-    static constexpr auto I3 = Number<3>{};
-
-    static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, K1>{};
-
-    static constexpr index_t WaveSize = 64;
-
-    static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
-    static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
-    static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
-    static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
-
-    static constexpr index_t MWaves = M1 / MPerWave;
-    static constexpr index_t NWaves = N1 / NPerWave;
-
-    static constexpr index_t MRepeat = M0;
-    static constexpr index_t NRepeat = N0;
-
-    __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
-    __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
-    __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
-
-    __device__ static auto CalculateAThreadOriginDataIndex()
-    {
-        const index_t thread_id = get_thread_local_1d_id();
-        const index_t waveId    = thread_id / WaveSize;
-        const index_t laneId    = thread_id % WaveSize;
-        const index_t waveId_m  = waveId / NWaves;
-
-        if constexpr(xdlops_gemm.IsKReduction)
-        {
-            const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
-            const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
-            return make_tuple(k_offset, 0, m_offset, 0);
-        }
-        else
-        {
-            const index_t m_offset = waveId_m * MPerWave + laneId;
-            const index_t k_offset = 0;
-            return make_tuple(k_offset, 0, m_offset, 0);
-        }
-    }
-
-    __device__ static auto CalculateBThreadOriginDataIndex()
-    {
-        const index_t thread_id = get_thread_local_1d_id();
-        const index_t waveId    = thread_id / WaveSize;
-        const index_t laneId    = thread_id % WaveSize;
-        const index_t waveId_n  = waveId % NWaves;
-
-        if constexpr(xdlops_gemm.IsKReduction)
-        {
-            const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
-            const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
-            return make_tuple(k_offset, 0, n_offset, 0);
-        }
-        else
-        {
-            const index_t n_offset = waveId_n * NPerWave + laneId;
-            const index_t k_offset = 0;
-            return make_tuple(k_offset, 0, n_offset, 0);
-        }
-    }
-
-    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
-    __device__ static CIndex CalculateCThreadOriginDataIndex(
-        Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
-    {
-        const index_t waveId = get_thread_local_1d_id() / WaveSize;
-
-        const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
-
-        const index_t waveId_m = waveId / NWaves;
-        const index_t waveId_n = waveId % NWaves;
-
-        const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
-        const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
-
-        return CIndex{m_offset, n_offset};
-    }
-
-    __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline()
-        : a_thread_copy_{CalculateAThreadOriginDataIndex()},
-          b_thread_copy_{CalculateBThreadOriginDataIndex()}
-    {
-        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
-                      "wrong! Desc should be known at compile-time");
-
-        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
-                      "wrong! K dimension not consistent");
-        static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
-                      "wrong! K1 dimension not consistent");
-
-        static_assert(BlockSize == MWaves * NWaves * WaveSize,
-                      "BlockSize != MWaves * NWaves * WaveSize\n");
-
-        static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
-
-        constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
-
-        static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
-        static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
-    }
-
+    __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor()
+    {
+        constexpr auto M2 = Number<CXdlopsLayout.M1()>{};
+        constexpr auto M3 = Number<CXdlopsLayout.N1()>{};
+        constexpr auto M4 = Number<CXdlopsLayout.M0()>{};
+        constexpr auto N2 = Number<CXdlopsLayout.N0()>{};
+
+        return make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+                                                              Number<NRepeat>{},
+                                                              Number<MWaves>{},
+                                                              Number<NWaves>{},
+                                                              Number<M2>{},
+                                                              Number<M3>{},
+                                                              Number<M4>{},
+                                                              Number<N2>{}));
+    }
+
+    template <typename CMNGridDesc>
+    __host__ __device__ static constexpr auto
+    MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
+    {
+        ///\To-do: pass CGrid desc transform deep inside xdlops gemm
+        constexpr auto M2 = Number<CXdlopsLayout.M1()>{};
+        constexpr auto M3 = Number<CXdlopsLayout.N1()>{};
+        constexpr auto M4 = Number<CXdlopsLayout.M0()>{};
+        constexpr auto N2 = Number<CXdlopsLayout.N0()>{};
+
+        return transform_tensor_descriptor(
+            c_m_n_grid_desc,
+            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M2, M3, M4)),
+                       make_unmerge_transform(make_tuple(NRepeat, NWaves, N2))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
+    }
+
+    __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor()
+    {
+        return transform_tensor_descriptor(
+            AK0MK1BlockDesc{},
+            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{})),
+                       make_pass_through_transform(Number<K1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
+    __host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor()
+    {
+        return transform_tensor_descriptor(
+            BK0NK1BlockDesc{},
+            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{})),
+                       make_pass_through_transform(Number<K1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
+    static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor();
+    static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor();
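MakeAK0M0M1M2K1BlockDescriptor re-expresses the block's (K0, M, K1) coordinates with M unmerged into (MRepeat, MWaves, MPerXDL). A small standalone model of that split, with assumed lengths that are only for illustration:

    #include <cstdio>

    struct ACoord5 { int k0, m0, m1, m2, k1; };

    constexpr ACoord5 split_m(int k0, int m, int k1, int MWaves, int MPerXDL)
    {
        const int m2 = m % MPerXDL;
        const int m1 = (m / MPerXDL) % MWaves;
        const int m0 = m / (MPerXDL * MWaves); // repeat index
        return {k0, m0, m1, m2, k1};
    }

    int main()
    {
        constexpr int MWaves = 2, MPerXDL = 32; // illustrative
        const ACoord5 c = split_m(/*k0*/ 3, /*m*/ 70, /*k1*/ 1, MWaves, MPerXDL);
        std::printf("k0=%d m0=%d m1=%d m2=%d k1=%d\n", c.k0, c.m0, c.m1, c.m2, c.k1);
    }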
    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
...
...
@@ -359,165 +230,88 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
            b_thread_desc_.GetElementSpaceSize());

-       constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
-
-       // read A_sub_0
-       a_thread_copy_.Run(ABlockDesc{}, make_tuple(I0, I0, I0, I0), a_block_buf,
-                          a_thread_desc_, make_tuple(I0, I0, I0, I0), a_thread_buf);
-       // read B_sub_0
-       b_thread_copy_.Run(BBlockDesc{}, make_tuple(I0, I0, I0, I0), b_block_buf,
-                          b_thread_desc_, make_tuple(I0, I0, I0, I0), b_thread_buf);
-       // read B_sub_1
-       b_thread_copy_.Run(BBlockDesc{}, make_tuple(I0, I1, I0, I0), b_block_buf,
-                          b_thread_desc_, make_tuple(I0, I1, I0, I0), b_thread_buf);
-       // read A_sub_1
-       a_thread_copy_.Run(ABlockDesc{}, make_tuple(I0, I1, I0, I0), a_block_buf,
-                          a_thread_desc_, make_tuple(I0, I1, I0, I0), a_thread_buf);
-
-       // C_sub_00 += transpose(A_sub_0) * B_sub_0
-       xdlops_gemm.template Run<decltype(a_thread_desc_), decltype(b_thread_desc_), decltype(c_thread_desc_), 0, 0>(a_thread_buf, b_thread_buf, c_thread_buf);
-       // C_sub_01 += transpose(A_sub_0) * B_sub_1
-       xdlops_gemm.template Run<decltype(a_thread_desc_), decltype(b_thread_desc_), decltype(c_thread_desc_), 0, 1>(a_thread_buf, b_thread_buf, c_thread_buf);
-
-       static_for<xdlops_gemm.KPerXdlops, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
-           // read A_sub_0
-           a_thread_copy_.Run(ABlockDesc{}, make_tuple(k, I0, I0, I0), a_block_buf,
-                              a_thread_desc_, make_tuple(I0, I0, I0, I0), a_thread_buf);
-           // C_sub_10 += transpose(A_sub_1) * B_sub_0
-           xdlops_gemm.template Run<decltype(a_thread_desc_), decltype(b_thread_desc_), decltype(c_thread_desc_), 1, 0>(a_thread_buf, b_thread_buf, c_thread_buf);
-           // read B_sub_0
-           b_thread_copy_.Run(BBlockDesc{}, make_tuple(k, I0, I0, I0), b_block_buf,
-                              b_thread_desc_, make_tuple(I0, I0, I0, I0), b_thread_buf);
-           // C_sub_11 += transpose(A_sub_1) * B_sub_1
-           xdlops_gemm.template Run<decltype(a_thread_desc_), decltype(b_thread_desc_), decltype(c_thread_desc_), 1, 1>(a_thread_buf, b_thread_buf, c_thread_buf);
-           // read B_sub_1
-           b_thread_copy_.Run(BBlockDesc{}, make_tuple(k, I1, I0, I0), b_block_buf,
-                              b_thread_desc_, make_tuple(I0, I1, I0, I0), b_thread_buf);
-           // read A_sub_1
-           a_thread_copy_.Run(ABlockDesc{}, make_tuple(k, I1, I0, I0), a_block_buf,
-                              a_thread_desc_, make_tuple(I0, I1, I0, I0), a_thread_buf);
-           // C_sub_00 += transpose(A_sub_0) * B_sub_0
-           xdlops_gemm.template Run<decltype(a_thread_desc_), decltype(b_thread_desc_), decltype(c_thread_desc_), 0, 0>(a_thread_buf, b_thread_buf, c_thread_buf);
-           // C_sub_01 += transpose(A_sub_0) * B_sub_1
-           xdlops_gemm.template Run<decltype(a_thread_desc_), decltype(b_thread_desc_), decltype(c_thread_desc_), 0, 1>(a_thread_buf, b_thread_buf, c_thread_buf);
-       });

+       vector_type<FloatAB, K1> a_thread_vec;
+       vector_type<FloatAB, K1> b_thread_vec;
+
+       static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
+           // read A
+           a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
+                              make_tuple(k, I0, I0, I0, I0),
+                              a_block_buf,
+                              a_thread_desc_,
+                              make_tuple(I0, I0, I0, I0, I0),
+                              a_thread_buf);
+
+           // read B
+           b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
+                              make_tuple(k, I0, I0, I0, I0),
+                              b_block_buf,
+                              b_thread_desc_,
+                              make_tuple(I0, I0, I0, I0, I0),
+                              b_thread_buf);
+
+           using mfma_input_type =
+               typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_base>::type;
+
+           static_for<0, MRepeat, 1>{}([&](auto m0) {
+               static_for<0, NRepeat, 1>{}([&](auto n0) {
+                   static_for<0, K1, 1>{}([&](auto i) {
+                       a_thread_vec.template AsType<FloatAB>()(Number<i>{}) = a_thread_buf
+                           [Number<a_thread_desc_.CalculateOffset(make_tuple(0, m0, 0, 0, i))>{}];
+                   });
+
+                   static_for<0, K1, 1>{}([&](auto i) {
+                       b_thread_vec.template AsType<FloatAB>()(Number<i>{}) = b_thread_buf
+                           [Number<b_thread_desc_.CalculateOffset(make_tuple(0, n0, 0, 0, i))>{}];
+                   });
+
+                   constexpr index_t c_offset =
+                       c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                   xdlops_gemm.template Run<c_offset>(
+                       a_thread_vec.template AsType<mfma_input_type>(),
+                       b_thread_vec.template AsType<mfma_input_type>(),
+                       c_thread_buf);
+               });
+           });
+       });
    }

    private:
    // A[K, M]
-   static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
-       make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
+   static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
+       make_tuple(I1, Number<MRepeat>{}, I1, I1, Number<K1>{}));

    // B[K, N]
-   static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
-       make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
+   static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
+       make_tuple(I1, Number<NRepeat>{}, I1, I1, Number<K1>{}));

-   static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
-       make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
+   static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
+       make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<xdlops_gemm.GetNumXdlops()>{}));

-   using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
-                                                        FloatAB,
-                                                        ABlockDesc,
-                                                        decltype(a_thread_desc_),
-                                                        Sequence<1, 1, 1, K1>,
-                                                        Sequence<0, 1, 2, 3>,
-                                                        3,
-                                                        1, // K1,
-                                                        1>;
+   using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
+                                                        FloatAB,
+                                                        decltype(a_k0_m0_m1_m2_k1_block_desc),
+                                                        decltype(a_thread_desc_),
+                                                        Sequence<1, MRepeat, 1, 1, K1>,
+                                                        Sequence<0, 1, 2, 3, 4>,
+                                                        4,
+                                                        K1,
+                                                        1>;

-   using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
-                                                        FloatAB,
-                                                        BBlockDesc,
-                                                        decltype(b_thread_desc_),
-                                                        Sequence<1, 1, 1, K1>,
-                                                        Sequence<0, 1, 2, 3>,
-                                                        3,
-                                                        1, // K1,
-                                                        1>;
+   using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
+                                                        FloatAB,
+                                                        decltype(b_k0_n0_n1_n2_k1_block_desc),
+                                                        decltype(b_thread_desc_),
+                                                        Sequence<1, NRepeat, 1, 1, K1>,
+                                                        Sequence<0, 1, 2, 3, 4>,
+                                                        4,
+                                                        K1,
+                                                        1>;

-   AThreadCopy a_thread_copy_;
-   BThreadCopy b_thread_copy_;
+   AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
+   BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};

} // namespace ck
...
...
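The per-tile accumulator offset in the new Run is c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)) on a packed (MRepeat, NRepeat, NumXdlops) descriptor, which is a plain row-major linearization. A tiny sketch with assumed lengths:

    #include <cstdio>

    constexpr int packed_offset(int m0, int n0, int x, int NRepeat, int NumXdlops)
    {
        // row-major linearization of (m0, n0, x) over lengths (MRepeat, NRepeat, NumXdlops)
        return (m0 * NRepeat + n0) * NumXdlops + x;
    }

    int main()
    {
        constexpr int MRepeat = 4, NRepeat = 4, NumXdlops = 1; // illustrative
        static_assert(packed_offset(2, 3, 0, NRepeat, NumXdlops) == 11, "row-major packing");
        std::printf("MRepeat=%d: offset(2,3,0) = %d\n",
                    MRepeat, packed_offset(2, 3, 0, NRepeat, NumXdlops));
    }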
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
@@ -18,7 +18,7 @@ template <typename GridwiseGemm,
          typename FloatC,
          typename AK0MK1GridDesc,
          typename BK0NK1GridDesc,
-         typename CM0M1M2NGridDesc,
+         typename CM0N0M1N1M2M3M4N2GridDesc,
          typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
...
...
@@ -29,7 +29,7 @@ __global__ void
        FloatC* __restrict__ p_c_grid,
        const AK0MK1GridDesc a_k0_m_k1_grid_desc,
        const BK0NK1GridDesc b_k0_n_k1_grid_desc,
-       const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
+       const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
        const CBlockClusterAdaptor c_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
...
...
@@ -43,7 +43,7 @@ __global__ void
                      p_shared_block,
                      a_k0_m_k1_grid_desc,
                      b_k0_n_k1_grid_desc,
-                     c_m0_m1_m2_n_grid_desc,
+                     c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                      c_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
...
...
@@ -52,7 +52,7 @@ template <typename GridwiseGemm,
          typename FloatC,
          typename AK0MK1GridDesc,
          typename BK0NK1GridDesc,
-         typename CM0M1M2NGridDesc,
+         typename CM0N0M1N1M2M3M4N2GridDesc,
          typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
...
...
@@ -63,7 +63,7 @@ __global__ void
        FloatC* __restrict__ p_c_grid,
        const void CONSTANT* p_a_k0_m_k1_grid_desc,
        const void CONSTANT* p_b_k0_n_k1_grid_desc,
-       const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
+       const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
        const void CONSTANT* p_c_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
...
...
@@ -73,8 +73,9 @@ __global__ void
        cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
    const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
        cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
-   const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast<const CM0M1M2NGridDesc*>(
-       cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc));
+   const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
+       *reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
+           cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
    const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
        cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
...
...
@@ -86,7 +87,7 @@ __global__ void
                      p_shared_block,
                      a_k0_m_k1_grid_desc,
                      b_k0_n_k1_grid_desc,
-                     c_m0_m1_m2_n_grid_desc,
+                     c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                      c_block_cluster_adaptor);
}
#endif
...
...
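The void-pointer kernel variant type-erases each grid descriptor to const void CONSTANT* and recovers it inside the kernel with reinterpret_cast. A host-only toy model of that pattern (struct fields and names here are invented for illustration):

    #include <cstdio>

    struct GridDesc   // stand-in for one of the *GridDesc specializations
    {
        int K0, M, K1;
    };

    void kernel_like(const void* p_desc)
    {
        // recover the typed descriptor from the type-erased pointer
        const auto desc = *reinterpret_cast<const GridDesc*>(p_desc);
        std::printf("K0=%d M=%d K1=%d\n", desc.K0, desc.M, desc.K1);
    }

    int main()
    {
        const GridDesc a{8, 256, 4};
        kernel_like(&a); // in the real driver the struct is first copied to a device buffer
    }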
@@ -138,6 +139,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
+   static constexpr auto I4 = Number<4>{};
+   static constexpr auto I5 = Number<5>{};
+   static constexpr auto I6 = Number<6>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};
...
...
@@ -201,29 +205,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
    }

    __host__ __device__ static constexpr auto
-   MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
+   MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
    {
-       constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
-
-       constexpr auto CLayout = xdlops_gemm.GetCLayout();
-
-       constexpr auto M0 = Number<CLayout.M1()>{};
-       constexpr auto M1 = Number<CLayout.N1()>{};
-       constexpr auto M2 = Number<CLayout.M0()>{};
-       constexpr auto N1 = Number<CLayout.N0()>{};
-
-       const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor(
-           c_m_n_grid_desc,
-           make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
-                      make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
-           make_tuple(Sequence<0>{}, Sequence<1>{}),
-           make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
-
-       return c_m0_m1_m2_n_grid_desc;
+       constexpr auto max_lds_align = K1;
+
+       constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
+       constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
+
+       constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
+           make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+
+       constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
+           make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+
+       using BlockwiseGemm =
+           BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                               FloatAB,
+                                                               decltype(a_k0_m_k1_block_desc),
+                                                               decltype(b_k0_n_k1_block_desc),
+                                                               MPerWave,
+                                                               NPerWave,
+                                                               MRepeat,
+                                                               NRepeat,
+                                                               K1>;
+
+       return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
    }

    __host__ __device__ static constexpr auto
...
...
@@ -253,7 +256,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        return c_blockid_to_m0_n0_block_cluster_adaptor;
    }

-   using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{}));
+   using CM0N0M1N1M2M3M4N2GridDesc =
+       decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));

    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
...
...
@@ -262,7 +265,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                               FloatAB* __restrict__ p_shared_block,
                               const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                               const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
-                              const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
+                              const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                               const CBlockClusterAdaptor& c_block_cluster_adaptor)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
...
...
@@ -270,7 +273,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-           p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
+           p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());

        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
...
...
@@ -358,50 +361,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        // register
        // sanity check
        static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
                          NPerBlock % (NPerWave * NRepeat) == 0,
                      "wrong!");

-       constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor(
-           a_k0_m_k1_block_desc,
-           make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
-                      make_unmerge_transform(
-                          make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
-                      make_pass_through_transform(K1)),
-           make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-           make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
-
-       constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor(
-           b_k0_n_k1_block_desc,
-           make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
-                      make_unmerge_transform(
-                          make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
-                      make_pass_through_transform(K1)),
-           make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-           make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
-
-       const auto blockwise_gemm =
-           BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
-                                                FloatAB,
-                                                decltype(a_k0_m0_m1_k1_block_desc),
-                                                decltype(b_k0_n0_n1_k1_block_desc),
-                                                MPerWave,
-                                                NPerWave,
-                                                K1>{};
+       const auto blockwise_gemm =
+           BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+                                                               FloatAB,
+                                                               decltype(a_k0_m_k1_block_desc),
+                                                               decltype(b_k0_n_k1_block_desc),
+                                                               MPerWave,
+                                                               NPerWave,
+                                                               MRepeat,
+                                                               NRepeat,
+                                                               K1>{};

-       constexpr auto CLayout = blockwise_gemm.GetCLayout();
-
-       constexpr index_t BlkSize   = CLayout.GetBlkSize();
-       constexpr index_t NumBlks   = CLayout.GetNumBlks();
-       constexpr index_t NumXdlops = CLayout.GetNumXdlops();
-
-       static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");

        constexpr auto c_mr_nr_blk_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

+       constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
+           blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDesc();
+
+       constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize();

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
-                    vector_type<FloatAcc, BlkSize>,
+                    vector_type<FloatAcc, CBlkSize>,
                     c_mr_nr_blk_desc.GetElementSpaceSize(),
                     true>
            c_thread_buf;
...
...
@@ -474,94 +453,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
#if 0
// output: register to global memory
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr index_t N0 = CLayout.N1();
constexpr index_t N1 = CLayout.N0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<1>{},
Number<1>{},
Number<M0>{},
Number<1>{},
Number<M2>{},
Number<1>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i));
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(Number<blk_off * BlkSize + j>{}) =
c_thread_buf[Number<blk_off>{}]
.template AsType<FloatAcc>()[Number<j>{}];
});
});
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
ThreadwiseTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<MRepeat, NRepeat, 1, 1, M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_m2_n_grid_desc,
make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves),
n_thread_data_on_grid / (N1 * NWaves),
m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0),
n_thread_data_on_grid % (N1 * NWaves) / N1,
m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1),
m_thread_data_on_grid % (M2 * M1) / M2,
m_thread_data_on_grid % M2,
n_thread_data_on_grid % N1)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_blk_buf_,
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
}
#else
        {
-           constexpr index_t M0 = CLayout.M1();
-           constexpr index_t M1 = CLayout.N1();
-           constexpr index_t M2 = CLayout.M0();
-
-           constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed(
-               make_tuple(I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
+           constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
+               blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
+
+           constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
+           constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
+           constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
...
...
@@ -574,92 +473,96 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

-           constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
+           constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks = CGridStepHacks{};

            auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
                FloatC,
                FloatC,
-               decltype(c_m0_m1_m2_n_thread_desc),
-               decltype(c_m0_m1_m2_n_grid_desc),
-               Sequence<1, 1, 1, 1, M0, 1, M2, 1>,
+               decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
+               decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
+               Sequence<I1, I1, I1, I1, M2, I1, M4, I1>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{
-               c_m0_m1_m2_n_grid_desc,
+               c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                make_multi_index(0,
                                 0,
                                 0,
                                 0,
-                                m_thread_data_on_grid / (M2 * M1),
-                                m_thread_data_on_grid % (M2 * M1) / M2,
-                                m_thread_data_on_grid % M2,
+                                m_thread_data_on_grid / (M3 * M4),
+                                m_thread_data_on_grid % (M3 * M4) / M4,
+                                m_thread_data_on_grid % M4,
                                 n_thread_data_on_grid)};

            auto init_copy = [&](auto c_thread_idx_) {
                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);

                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks);

                return c_thread_idx_;
            };

            auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
                constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);

                c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                 mrepeat_step_plus);

                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);

                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks);
            };

            auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
                constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);

                c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                 nrepeat_step_plus);

                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);

                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks);
            };

            auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
                constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);

                c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                 mrepeat_step_plus);

                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);

                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks);
            };

            auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
                constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);

                c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                 nrepeat_step_minus);

                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);

                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks);
            };

            static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
...
...
@@ -791,7 +694,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                init_copy(make_tuple(I0, I0));
            }
        }
-#endif
    }
};

} // namespace ck
...
...
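A plausible reading of the init/mrepeat/nrepeat plus and minus copy lambdas, together with the static_assert on specific (MRepeat, NRepeat) pairs, is a serpentine walk over the repeat grid so that each step moves the destination window by one along a single dimension. A standalone illustration under that assumption, with made-up tile counts:

    #include <cstdio>

    int main()
    {
        constexpr int MRepeat = 4, NRepeat = 2; // illustrative only

        for(int m = 0; m < MRepeat; ++m)
        {
            // even rows walk n forward, odd rows walk n backward
            for(int i = 0; i < NRepeat; ++i)
            {
                const int n = (m % 2 == 0) ? i : (NRepeat - 1 - i);
                std::printf("copy tile (m0=%d, n0=%d)\n", m, n);
            }
        }
    }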
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
@@ -709,19 +709,59 @@ struct XdlopsGemm
        static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!");
    }

+   template <typename CM0N0M1N1M2N2GridDesc>
+   __host__ __device__ static constexpr auto
+   MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CM0N0M1N1M2N2GridDesc& c_m0_n0_m1_n1_m2_n2_grid_desc)
+   {
+       constexpr auto I0 = Number<0>{};
+       constexpr auto I1 = Number<1>{};
+       constexpr auto I2 = Number<2>{};
+       constexpr auto I3 = Number<3>{};
+       constexpr auto I4 = Number<4>{};
+       constexpr auto I5 = Number<5>{};
+
+       constexpr auto M0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I0);
+       constexpr auto N0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I1);
+       constexpr auto M1 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I2);
+       constexpr auto N1 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I3);
+       constexpr auto M2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I4);
+       constexpr auto N2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I5);
+
+       static_assert(N2 == mfma_type.num_threads_blk, "");
+       static_assert(M2 == (mfma_type.num_groups_blk * mfma_type.num_output_blks *
+                            mfma_type.group_size),
+                     "");
+
+       return transform_dynamic_tensor_descriptor(
+           c_m0_n0_m1_n1_m2_n2_grid_desc,
+           make_tuple(make_pass_through_transform(M0),
+                      make_pass_through_transform(N0),
+                      make_pass_through_transform(M1),
+                      make_pass_through_transform(N1),
+                      make_unmerge_transform(make_tuple(mfma_type.num_groups_blk,
+                                                        mfma_type.num_input_blks,
+                                                        mfma_type.group_size)),
+                      make_pass_through_transform(mfma_type.num_threads_blk)),
+           make_tuple(Sequence<0>{},
+                      Sequence<1>{},
+                      Sequence<2>{},
+                      Sequence<3>{},
+                      Sequence<4>{},
+                      Sequence<5>{}),
+           make_tuple(Sequence<0>{},
+                      Sequence<1>{},
+                      Sequence<2>{},
+                      Sequence<3>{},
+                      Sequence<4, 5, 6>{},
+                      Sequence<7>{}));
+   }

    __device__ static constexpr index_t GetRegSizePerXdlops()
    {
        return MPerXdlops * NPerXdlops / mfma_type.wave_size;
    }

-   template <class ADesc,
-             class BDesc,
-             class CDesc,
-             index_t m0,
-             index_t n0,
-             class FloatA,
-             class FloatB,
-             class FloatC>
+   template <index_t c_offset, class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
    {
        static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
...
...
@@ -730,24 +770,35 @@ struct XdlopsGemm
        static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base");

-       constexpr index_t c_offset =
-           CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
-
-       static_for<0, KPack, mfma_type.k_base>{}([&](auto k) {
-           constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
-           constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));
-
+       static_for<0, KPack / mfma_type.k_base, 1>{}([&](auto k) {
            mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
-               p_a_wave[Number<a_offset / mfma_type.k_base>{}],
-               p_b_wave[Number<b_offset / mfma_type.k_base>{}],
-               p_c_thread);
+               p_a_wave[k], p_b_wave[k], p_c_thread);
        });
    }

+   static constexpr auto GetBlkIdx()
+   {
+       const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
+           make_tuple(make_merge_transform(
+               make_tuple(1, mfma_type.num_input_blks, mfma_type.num_threads_blk))),
+           make_tuple(Sequence<0, 1, 2>{}),
+           make_tuple(Sequence<0>{}));
+
+       const auto blk_idx = threadidx_to_blk_idx_adaptor.CalculateBottomIndex(
+           make_multi_index(get_thread_local_1d_id()));
+
+       const auto blk_id = blk_idx[Number<1>{}];
+       const auto blk_td = blk_idx[Number<2>{}];
+
+       return make_tuple(blk_id, blk_td);
+   }

    __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
    {
-       const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
-       const index_t blk_id = laneId / mfma_type.num_threads_blk;
-       const index_t blk_td = laneId % mfma_type.num_threads_blk;
+       const auto blk_idx = GetBlkIdx();
+       const auto blk_id  = blk_idx[Number<0>{}];
+       const auto blk_td  = blk_idx[Number<1>{}];

        index_t n_offset = blk_i * mfma_type.n + blk_td;
        index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size;
...
...
@@ -755,24 +806,12 @@ struct XdlopsGemm
        return CIndex{m_offset, n_offset};
    }

-   static constexpr index_t MRepeats   = GetXdlopsInfo().MRepeats;
-   static constexpr index_t NRepeats   = GetXdlopsInfo().NRepeats;
    static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
    static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
-   static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();

    static constexpr bool IsKReduction  = GetXdlopsInfo().IsKReduction();
    static constexpr bool IsABroadcast  = GetXdlopsInfo().IsABroadcast();
+   static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();

-   static constexpr auto GetBlkId(const index_t lane_id)
-   {
-       return lane_id / mfma_type.num_threads_blk;
-   }
-
-   static constexpr auto GetBlkTd(const index_t lane_id)
-   {
-       return lane_id % mfma_type.num_threads_blk;
-   }

    static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
...
...
@@ -794,7 +833,7 @@ struct XdlopsGemm
    }
};

-   __host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; }
+   __host__ __device__ static constexpr auto GetCXdlopsLayout() { return CLayout{}; }
};

} // namespace ck
...
...
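GetBlkIdx splits a lane id into (blk_id, blk_td) using num_input_blks and num_threads_blk, and GetBeginOfThreadBlk turns that into the thread's first C element of its block. A standalone numeric sketch; the mfma geometry values below are assumptions, not taken from the file:

    #include <cstdio>

    int main()
    {
        constexpr int num_threads_blk = 32; // illustrative
        constexpr int group_size      = 4;  // illustrative

        const int lane   = 70 % 64;         // lane id within the 64-wide wave
        const int blk_id = lane / num_threads_blk;
        const int blk_td = lane % num_threads_blk;

        std::printf("lane=%d -> blk_id=%d blk_td=%d, C origin=(%d, %d)\n",
                    lane, blk_id, blk_td, blk_id * group_size, blk_td);
    }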
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
@@ -129,9 +129,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
}
const
auto
c_m0_m1_m2_n_grid_desc
=
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
const
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
GridwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
c_m_n_grid_desc
);
using
CM0
M1M2N
GridDesc
=
decltype
(
c_m0_
m1_m2_n
_grid_desc
);
using
CM0
N0M1N1M2M3M4N2
GridDesc
=
decltype
(
c_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
);
...
...
@@ -144,7 +145,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
        FloatC,
        remove_reference_t<AK0MK1GridDesc>,
        remove_reference_t<BK0NK1GridDesc>,
-       remove_reference_t<CM0M1M2NGridDesc>,
+       remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
        remove_reference_t<CBlockClusterAdaptor>>;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
...
...
@@ -158,18 +159,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                      p_c_grid,
                                      a_k0_m_k1_grid_desc,
                                      b_k0_n_k1_grid_desc,
-                                     c_m0_m1_m2_n_grid_desc,
+                                     c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                      c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
    DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
-   DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
+   DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
    DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));

    a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
    b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
-   c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
+   c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
    c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);

    float ave_time = launch_and_time_kernel(
...
...
@@ -183,7 +184,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
        p_c_grid,
        cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
        cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
-       cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()),
+       cast_pointer_to_constant_address_space(
+           c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
        cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif

    return ave_time;
...
...
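launch_and_time_kernel returns the averaged kernel time that the driver reports as ave_time. A rough host-only analogue of that measurement pattern, with no GPU involved and invented names:

    #include <chrono>
    #include <cstdio>

    template <typename F>
    float time_avg_ms(F&& f, int nrepeat)
    {
        const auto t0 = std::chrono::steady_clock::now();
        for(int i = 0; i < nrepeat; ++i)
            f();
        const auto t1 = std::chrono::steady_clock::now();
        return std::chrono::duration<float, std::milli>(t1 - t0).count() / nrepeat;
    }

    int main()
    {
        volatile double sink = 0;
        const float ave_time =
            time_avg_ms([&] { for(int i = 0; i < 1000000; ++i) sink = sink + i; }, 10);
        std::printf("ave_time = %f ms\n", ave_time);
    }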
host/driver_offline/src/conv_fwd_driver_offline.cpp
@@ -24,7 +24,7 @@
 #define USE_CONV_FWD_V4R4R2_NHWC 1
 #define USE_CONV_FWD_V6R1_NCHW 0
 #define USE_CONV_FWD_V5R1_NCHW 0
-#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
+#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1
 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 0

 enum ConvForwardAlgo
...
...