gaoqiong / composable_kernel · Commits

Commit 1d48b521, authored May 18, 2021 by Jing Zhang (parent c0ffe379)

clean code

Showing 6 changed files with 220 additions and 17 deletions (+220 -17)
Changed files:
- composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp (+4 -0)
- composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp (+197 -2)
- composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp (+2 -2)
- composable_kernel/include/tensor_operation/xdlops_gemm.hpp (+3 -4)
- composable_kernel/include/utility/amd_xdlops.hpp (+6 -6)
- driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp (+8 -3)
composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp
```
@@ -24,6 +24,8 @@ template <index_t BlockSize,
          index_t MPerWave,
          index_t NPerWave,
          index_t KPerWave,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K_M,
          typename ABlockTransferThreadClusterLengths_K_M,
          typename ABlockTransferThreadClusterArrangeOrder,
```
```
@@ -99,6 +101,8 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
        MPerWave,
        NPerWave,
        KPerWave,
        MRepeat,
        NRepeat,
        ABlockTransferThreadSliceLengths_K_M,
        ABlockTransferThreadClusterLengths_K_M,
        ABlockTransferThreadClusterArrangeOrder,
```
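Both hunks thread a new KPerWave tiling factor through the launcher's template parameter list. As a minimal sketch of the idea (not the composable_kernel API; all names below are hypothetical), such tiling factors travel as compile-time non-type template parameters so the inner loops can be sized and unrolled at compile time:

```cpp
#include <cstdint>

using index_t = std::int32_t;

// Hypothetical sketch: K tiling factors carried as compile-time constants.
template <index_t KPerBlock, index_t KPerWave>
struct KLoopSketch
{
    static_assert(KPerBlock % KPerWave == 0, "KPerWave must divide KPerBlock");

    // number of outer K steps each wave performs over one block tile
    static constexpr index_t k_steps = KPerBlock / KPerWave;
};

static_assert(KLoopSketch<8, 2>::k_steps == 4, "e.g. GemmKPerBlock = 8, GemmKPerWave = 2");
```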
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
```
@@ -125,8 +125,203 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
                      "wrong! K dimension not consistent");

        // static_assert(MPerWave * MWaves == MPerBlock, "GemmMWaves * MPerWave != M");
        // static_assert(NPerWave * NWaves == NPerBlock, "GemmNWaves * NPerWave != N");

        static_assert(BlockSize == MWaves * NWaves * WaveSize,
                      "BlockSize != MWaves * NWaves * WaveSize\n");
    }

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
            b_thread_desc_.GetElementSpaceSize());

        constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);

        static_for<0, KPerBlock, KPerWave>{}([&](auto k) {
            // read A
            a_thread_copy_.Run(ABlockDesc{},
                               make_tuple(k, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, I0),
                               a_thread_buf);

            // read B
            b_thread_copy_.Run(BBlockDesc{},
                               make_tuple(k, I0, I0),
                               b_block_buf,
                               b_thread_desc_,
                               make_tuple(I0, I0, I0),
                               b_thread_buf);

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    xdlops_gemm.template Run2<decltype(a_thread_desc_),
                                              decltype(b_thread_desc_),
                                              decltype(c_thread_desc_),
                                              m0,
                                              n0>(a_thread_buf, b_thread_buf, c_thread_buf);
                });
            });
        });
    }
```
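Logically, Run() walks the K dimension in KPerWave-sized steps and, for each step, issues one xdlops_gemm.Run2 per (m0, n0) repeat. A plain-loop reference model of that structure follows; it illustrates the loop nest over per-thread fragments only and ignores the wave-wide XDLOPS register layout the real code operates on (data layout below is hypothetical). The same hunk continues after the sketch.

```cpp
#include <cstdint>

using index_t = std::int32_t;

// Reference model of the loop structure only; the fragment layout is assumed.
template <index_t KPerBlock, index_t KPerWave, index_t MRepeat, index_t NRepeat>
void blockwise_gemm_reference(const float* a, // [KPerBlock][MRepeat] fragment view
                              const float* b, // [KPerBlock][NRepeat] fragment view
                              float* c)       // [MRepeat][NRepeat] accumulators
{
    for(index_t k = 0; k < KPerBlock; k += KPerWave)     // static_for<0, KPerBlock, KPerWave>
        for(index_t m0 = 0; m0 < MRepeat; ++m0)          // static_for<0, MRepeat, 1>
            for(index_t n0 = 0; n0 < NRepeat; ++n0)      // static_for<0, NRepeat, 1>
                for(index_t kk = 0; kk < KPerWave; ++kk) // one Run2: a KPerWave-deep MAC
                    c[m0 * NRepeat + n0] +=
                        a[(k + kk) * MRepeat + m0] * b[(k + kk) * NRepeat + n0];
}
```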
```
    private:
    // A[K, M]
    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerWave>{}, Number<MRepeat>{}, I1));

    // B[K, N]
    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerWave>{}, Number<NRepeat>{}, I1));

    static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                                FloatA,
                                                                ABlockDesc,
                                                                decltype(a_thread_desc_),
                                                                Sequence<KPerWave, MRepeat, 1>,
                                                                Sequence<0, 1, 2>,
                                                                2,
                                                                1,
                                                                1>;

    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                                FloatB,
                                                                BBlockDesc,
                                                                decltype(b_thread_desc_),
                                                                Sequence<KPerWave, NRepeat, 1>,
                                                                Sequence<0, 1, 2>,
                                                                2,
                                                                1,
                                                                1>;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};
```
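The three thread descriptors above are packed ("naive") layouts over (KPerWave, MRepeat, 1), (KPerWave, NRepeat, 1) and (MRepeat, NRepeat). Assuming the usual row-major packed offset rule (an assumption about make_dynamic_naive_tensor_descriptor_packed_v2, not something this hunk states), element (k, m, 0) of the A-side thread buffer sits at linear offset k * MRepeat + m. The hunk continues below.

```cpp
// Packed-offset sketch under the row-major assumption stated above.
constexpr int a_offset(int k, int m, int MRepeat) { return (k * MRepeat + m) * 1; }

static_assert(a_offset(0, 1, 2) == 1, "second M fragment of the first K slice");
static_assert(a_offset(1, 0, 2) == 2, "first M fragment of the second K slice");
```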
```
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          class ABlockDesc,
          class BBlockDesc,
          index_t MPerWave,
          index_t NPerWave,
          index_t KPerWave>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
{
    using CIndex = MultiIndex<2>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPerWave>{};

    static constexpr index_t WaveSize = 64;

    static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
    static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);

    static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
    static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);

    static constexpr index_t MWaves = M1 / MPerWave;
    static constexpr index_t NWaves = N1 / NPerWave;

    static constexpr index_t MRepeat = M0;
    static constexpr index_t NRepeat = N0;
```
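To see how these constants interact, here is a worked example with hypothetical block-descriptor extents (not taken from this commit): an A block descriptor of shape [K, M0, M1] = [8, 2, 128] with MPerWave = 64 gives MWaves = 2 and MRepeat = 2; together with an assumed NWaves = 2 and WaveSize = 64, the BlockSize assertion in the constructor would require 256 threads. The hunk continues below.

```cpp
// Hypothetical numbers for illustration only.
constexpr int M0 = 2, M1 = 128, MPerWave = 64, WaveSize = 64, NWaves = 2;
constexpr int MWaves  = M1 / MPerWave; // 2
constexpr int MRepeat = M0;            // 2
static_assert(MWaves * NWaves * WaveSize == 256, "BlockSize would have to be 256");
```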
```
    __device__ constexpr auto GetOutputLayout() const { return xdlops_gemm.GetOutputLayout(); }

    __device__ constexpr auto GetNumBlks() const
    {
        return xdlops_gemm.GetOutputLayout().GetNumBlks();
    }

    __device__ constexpr auto GetBlkSize() const
    {
        return xdlops_gemm.GetOutputLayout().GetBlkSize();
    }

    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const index_t thread_id = get_thread_local_1d_id();
        const index_t waveId    = thread_id / WaveSize;
        const index_t laneId    = thread_id % WaveSize;
        const index_t waveId_m  = waveId / NWaves;
        const index_t waveId_n  = waveId % NWaves;

        if constexpr(xdlops_gemm.IsKReduction)
        {
            const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
            const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base;
            return make_tuple(k_offset, 0, m_offset);
        }
        else
        {
            const index_t m_offset = waveId_m * MPerWave + laneId;
            const index_t k_offset = 0;
            return make_tuple(k_offset, 0, m_offset);
        }
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const index_t thread_id = get_thread_local_1d_id();
        const index_t waveId    = thread_id / WaveSize;
        const index_t laneId    = thread_id % WaveSize;
        const index_t waveId_m  = waveId / NWaves;
        const index_t waveId_n  = waveId % NWaves;

        if constexpr(xdlops_gemm.IsKReduction)
        {
            const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
            const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base;
            return make_tuple(k_offset, 0, n_offset);
        }
        else
        {
            const index_t n_offset = waveId_n * NPerWave + laneId;
            const index_t k_offset = 0;
            return make_tuple(k_offset, 0, n_offset);
        }
    }
```
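Both origin-index helpers start from the same wave/lane decomposition of the flat thread id. A worked example with hypothetical MWaves = NWaves = 2 (so a 256-thread block); the hunk continues after it.

```cpp
// Hypothetical decomposition of thread 130 in a 256-thread block.
constexpr int WaveSize  = 64, NWaves = 2;
constexpr int thread_id = 130;
constexpr int waveId    = thread_id / WaveSize; // 2
constexpr int laneId    = thread_id % WaveSize; // 2
constexpr int waveId_m  = waveId / NWaves;      // 1
constexpr int waveId_n  = waveId % NWaves;      // 0
static_assert(waveId_m == 1 && waveId_n == 0 && laneId == 2, "lane 2 of wave (1, 0)");
```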
```
    __device__ static CIndex CalculateCThreadOriginDataIndex(const index_t m0,
                                                             const index_t n0,
                                                             const index_t blk_i)
    {
        const index_t waveId = get_thread_local_1d_id() / WaveSize;

        const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(blk_i);

        const index_t waveId_m = waveId / NWaves;
        const index_t waveId_n = waveId % NWaves;

        const index_t row = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk.row;
        const index_t col = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk.col;

        return CIndex{row, col};
    }
```
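CalculateCThreadOriginDataIndex composes the repeat index, the wave coordinate, and the thread's position inside the XDLOPS output block into a row/column origin in the C tile. With hypothetical values M1 = N1 = 128, MPerWave = NPerWave = 64, wave (waveId_m, waveId_n) = (1, 0), repeats (m0, n0) = (0, 1), and a block-local thread position of (3, 5):

```cpp
// Hypothetical C-origin computation matching the formula above.
constexpr int M1 = 128, N1 = 128, MPerWave = 64, NPerWave = 64;
constexpr int row = 0 * M1 + 1 * MPerWave + 3; // m0 * M1 + waveId_m * MPerWave + blk row
constexpr int col = 1 * N1 + 0 * NPerWave + 5; // n0 * N1 + waveId_n * NPerWave + blk col
static_assert(row == 67 && col == 133, "this thread's accumulator origin in C");
```

The hunk ends with the constructor of the 2x2pipeline struct: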
```
    __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline()
        : a_thread_copy_{CalculateAThreadOriginDataIndex()},
          b_thread_copy_{CalculateBThreadOriginDataIndex()}
    {
        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
                      "wrong! K dimension not consistent");

        static_assert(BlockSize == MWaves * NWaves * WaveSize,
                      "BlockSize != MWaves * NWaves * WaveSize\n");
```
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
```
@@ -111,6 +111,8 @@ template <index_t BlockSize,
          index_t MPerWave,
          index_t NPerWave,
          index_t KPerWave,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K_M,
          typename ABlockTransferThreadClusterLengths_K_M,
          typename ABlockTransferThreadClusterArrangeOrder,
```
```
@@ -278,8 +280,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register

        // sanity check
        constexpr index_t MRepeat = 2;
        constexpr index_t NRepeat = 2;

        static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
                          NPerBlock % (NPerWave * NRepeat) == 0,
```
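The static_assert above is the divisibility condition that makes the wave/repeat decomposition of the block tile exact. With hypothetical tile sizes (this hunk does not show MPerBlock or NPerBlock, so the numbers below are illustrative only):

```cpp
// Illustrative check: a 128 x 128 block tile, 64 x 64 waves, 2 x 2 repeats.
constexpr int MPerBlock = 128, NPerBlock = 128;
constexpr int MPerWave = 64, NPerWave = 64, MRepeat = 2, NRepeat = 2;
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
                  NPerBlock % (NPerWave * NRepeat) == 0,
              "block tile must be covered exactly by waves times repeats");
```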
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
```
@@ -620,10 +620,9 @@ struct XdlopsGemm
            constexpr index_t b_offset = BDesc{}.CalculateOffset(make_multi_index(k, n0, 0));
            constexpr index_t c_offset = CDesc{}.CalculateOffset(make_multi_index(m0, n0));

            mfma_type.template run<MPerXdlops, NPerXdlops>(
                p_a_wave[Number<a_offset>{}],
                p_b_wave[Number<b_offset>{}],
                p_c_thread.template AsType<float16_t>()(Number<c_offset>{}));

            mfma_type.template run<MPerXdlops, NPerXdlops>(
                p_a_wave[Number<a_offset>{}],
                p_b_wave[Number<b_offset>{}],
                p_c_thread.template AsType<float32_t>());
        });
    }
```
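The a/b/c offsets are compile-time packed-descriptor offsets. Assuming CDesc is the packed (MRepeat, NRepeat) thread descriptor shown in blockwise_gemm_xdlops.hpp (an assumption, since CDesc is a template parameter here), CalculateOffset(make_multi_index(m0, n0)) reduces to m0 * NRepeat + n0:

```cpp
// Packed C-offset sketch under the assumption above.
constexpr int NRepeat = 2;
constexpr int c_offset_of(int m0, int n0) { return m0 * NRepeat + n0; }

static_assert(c_offset_of(0, 1) == 1, "fragment (0, 1)");
static_assert(c_offset_of(1, 1) == 3, "last fragment of a 2 x 2 repeat tile");
```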
composable_kernel/include/utility/amd_xdlops.hpp
```
@@ -240,13 +240,13 @@ struct intrin_mfma_f32_32x32x1f32;
template <>
struct intrin_mfma_f32_32x32x1f32<64, 64>
{
    __device__ static void Run(const float& reg_a, const float& reg_b, vector_type<float, 64>& reg_c)
    template <class FloatA, class FloatB, class FloatC>
    __device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
        reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
        reg_c(Number<0>{}) =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c[Number<0>{}], 1, 0, 0);
        reg_c(Number<1>{}) =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c[Number<1>{}], 1, 1, 0);
    }
};
```
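For orientation, mfma_f32_32x32x1f32 is the MFMA variant whose per-issue math is a K = 1 (rank-1) update of 32 x 32 accumulator blocks held across the wave's registers; the trailing integer arguments select the broadcast block, which is presumably why the wrapper issues the intrinsic twice. A scalar reference of one such block, ignoring the wave-level register layout entirely (a sketch, not the hardware behaviour in detail):

```cpp
// Scalar reference of a single rank-1 (K = 1) 32 x 32 accumulation step.
void mfma_32x32x1_reference(const float a[32], const float b[32], float c[32][32])
{
    for(int m = 0; m < 32; ++m)
        for(int n = 0; n < 32; ++n)
            c[m][n] += a[m] * b[n];
}
```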
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
```
@@ -108,9 +108,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 8;
    constexpr index_t GemmMPerWave  = 32;
    constexpr index_t GemmNPerWave  = 32;
    constexpr index_t GemmKPerWave  = 2;
    constexpr index_t GemmMPerWave  = 64;
    constexpr index_t GemmNPerWave  = 64;
    constexpr index_t GemmKPerWave  = 1;
    constexpr index_t MRepeat       = 1;
    constexpr index_t NRepeat       = 1;

    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
```
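The retuned parameters enlarge the per-wave tile from 32 x 32 to 64 x 64 while dropping GemmKPerWave to 1 and introducing explicit MRepeat/NRepeat factors. Assuming the block tile is partitioned as GemmNPerBlock = NRepeat * NWaves * GemmNPerWave (an assumption based on the blockwise GEMM's wave decomposition, not stated in this hunk), the N extent of the block tile is now covered by a single wave:

```cpp
// Hypothetical consistency check of the new tuning values.
constexpr int GemmNPerBlock = 64, GemmNPerWave = 64, NRepeat = 1;
constexpr int NWaves = GemmNPerBlock / (NRepeat * GemmNPerWave); // 1
static_assert(NWaves == 1, "one wave covers the N extent of the block tile");
```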
```
@@ -159,6 +162,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
        GemmMPerWave,
        GemmNPerWave,
        GemmKPerWave,
        MRepeat,
        NRepeat,
        GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
        GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
        Sequence<1, 0>,
```