Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3399ddaf
Commit
3399ddaf
authored
May 20, 2021
by
Jing Zhang
Browse files
break vector type to blk_size
parent
59462dca
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
141 additions
and
127 deletions
+141
-127
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+1
-1
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+10
-21
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
...include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
+56
-49
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+35
-35
composable_kernel/include/utility/amd_xdlops.hpp
composable_kernel/include/utility/amd_xdlops.hpp
+30
-12
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+9
-9
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
3399ddaf
...
@@ -111,7 +111,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
...
@@ -111,7 +111,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
float
,
GemmMPerWave
,
GemmNPerWave
,
GemmKPerWave
>
{};
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
float
,
GemmMPerWave
,
GemmNPerWave
,
GemmKPerWave
>
{};
constexpr
auto
CLayout
=
xdlops_gemm
.
Get
Output
Layout
();
constexpr
auto
CLayout
=
xdlops_gemm
.
Get
C
Layout
();
constexpr
index_t
M0
=
CLayout
.
M1
();
constexpr
index_t
M0
=
CLayout
.
M1
();
constexpr
index_t
M1
=
CLayout
.
N1
();
constexpr
index_t
M1
=
CLayout
.
N1
();
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
3399ddaf
...
@@ -42,17 +42,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -42,17 +42,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static
constexpr
index_t
MRepeat
=
M0
;
static
constexpr
index_t
MRepeat
=
M0
;
static
constexpr
index_t
NRepeat
=
N0
;
static
constexpr
index_t
NRepeat
=
N0
;
__device__
constexpr
auto
Get
Output
Layout
()
const
{
return
xdlops_gemm
.
Get
Output
Layout
();
}
__device__
constexpr
auto
Get
C
Layout
()
const
{
return
xdlops_gemm
.
Get
C
Layout
();
}
__device__
constexpr
auto
GetNumBlks
()
const
__device__
constexpr
auto
GetNumBlks
()
const
{
return
xdlops_gemm
.
GetCLayout
().
GetNumBlks
();
}
{
return
xdlops_gemm
.
GetOutputLayout
().
GetNumBlks
();
}
__device__
constexpr
auto
GetBlkSize
()
const
__device__
constexpr
auto
GetBlkSize
()
const
{
return
xdlops_gemm
.
GetCLayout
().
GetBlkSize
();
}
{
return
xdlops_gemm
.
GetOutputLayout
().
GetBlkSize
();
}
__device__
static
auto
CalculateAThreadOriginDataIndex
()
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
{
...
@@ -98,13 +92,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -98,13 +92,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
}
}
}
}
template
<
index_t
m0
,
index_t
n0
,
index_t
blk_i
>
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
CIndex
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
blk_i
>
)
__device__
static
CIndex
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
{
const
index_t
waveId
=
get_thread_local_1d_id
()
/
WaveSize
;
const
index_t
waveId
=
get_thread_local_1d_id
()
/
WaveSize
;
const
auto
thread_mtx_on_blk
=
xdlops_gemm
.
GetBeginOfThreadBlk
(
blk_i
);
const
auto
thread_mtx_on_blk
=
xdlops_gemm
.
GetBeginOfThreadBlk
(
xdlops_i
,
blk_i
);
const
index_t
waveId_m
=
waveId
/
NWaves
;
const
index_t
waveId_m
=
waveId
/
NWaves
;
const
index_t
waveId_n
=
waveId
%
NWaves
;
const
index_t
waveId_n
=
waveId
%
NWaves
;
...
@@ -240,17 +235,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
...
@@ -240,17 +235,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
static
constexpr
index_t
MRepeat
=
M0
;
static
constexpr
index_t
MRepeat
=
M0
;
static
constexpr
index_t
NRepeat
=
N0
;
static
constexpr
index_t
NRepeat
=
N0
;
__device__
constexpr
auto
Get
Output
Layout
()
const
{
return
xdlops_gemm
.
Get
Output
Layout
();
}
__device__
constexpr
auto
Get
C
Layout
()
const
{
return
xdlops_gemm
.
Get
C
Layout
();
}
__device__
constexpr
auto
GetNumBlks
()
const
__device__
constexpr
auto
GetNumBlks
()
const
{
return
xdlops_gemm
.
GetCLayout
().
GetNumBlks
();
}
{
return
xdlops_gemm
.
GetOutputLayout
().
GetNumBlks
();
}
__device__
constexpr
auto
GetBlkSize
()
const
__device__
constexpr
auto
GetBlkSize
()
const
{
return
xdlops_gemm
.
GetCLayout
().
GetBlkSize
();
}
{
return
xdlops_gemm
.
GetOutputLayout
().
GetBlkSize
();
}
__device__
static
auto
CalculateAThreadOriginDataIndex
()
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
{
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
View file @
3399ddaf
...
@@ -310,10 +310,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -310,10 +310,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
MPerWave
,
MPerWave
,
NPerWave
,
NPerWave
,
KPerWave
>
{};
KPerWave
>
{};
constexpr
auto
Output
Layout
=
blockwise_gemm
.
Get
Output
Layout
();
constexpr
auto
C
Layout
=
blockwise_gemm
.
Get
C
Layout
();
constexpr
index_t
BlkSize
=
OutputLayout
.
GetBlkSize
();
constexpr
index_t
BlkSize
=
CLayout
.
GetBlkSize
();
constexpr
index_t
NumBlks
=
OutputLayout
.
GetNumBlks
();
constexpr
index_t
NumBlks
=
CLayout
.
GetNumBlks
();
constexpr
index_t
NumXdlops
=
CLayout
.
GetNumXdlops
();
// constexpr auto c_mr_nr_nb_bk_thread_desc =
// constexpr auto c_mr_nr_nb_bk_thread_desc =
// make_dynamic_naive_tensor_descriptor_packed_v2( make_tuple(Number<MRepeat>{},
// make_dynamic_naive_tensor_descriptor_packed_v2( make_tuple(Number<MRepeat>{},
...
@@ -338,7 +339,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -338,7 +339,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
StaticBuffer
<
AddressSpace
::
Vgpr
,
vector_type
<
float
,
NumBlks
*
BlkSize
>
,
MRepeat
*
NRepeat
>
StaticBuffer
<
AddressSpace
::
Vgpr
,
vector_type
<
float
,
NumBlks
*
BlkSize
>
,
MRepeat
*
NRepeat
*
NumXdlops
>
c_thread_buf
;
c_thread_buf
;
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
);
...
@@ -471,9 +474,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -471,9 +474,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// output: register to global memory
// output: register to global memory
{
{
constexpr
index_t
M0
=
Output
Layout
.
M1
();
constexpr
index_t
M0
=
C
Layout
.
M1
();
constexpr
index_t
M1
=
Output
Layout
.
N1
();
constexpr
index_t
M1
=
C
Layout
.
N1
();
constexpr
index_t
M2
=
Output
Layout
.
M0
();
constexpr
index_t
M2
=
C
Layout
.
M0
();
constexpr
auto
c_m0_m1_m2_n_thread_desc
=
constexpr
auto
c_m0_m1_m2_n_thread_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
...
@@ -483,49 +486,53 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -483,49 +486,53 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
mr_i
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
mr_i
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
nr_i
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
nr_i
)
{
static_for
<
0
,
NumBlks
,
1
>
{}([
&
](
auto
blk_i
)
{
static_for
<
0
,
NumXdlops
,
1
>
{}([
&
](
auto
xdlops_i
)
{
static_for
<
0
,
BlkSize
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
NumBlks
,
1
>
{}([
&
](
auto
blk_i
)
{
c_blk_buf_
(
j
)
=
static_for
<
0
,
BlkSize
,
1
>
{}([
&
](
auto
j
)
{
c_thread_buf
[
Number
<
mr_i
*
NRepeat
+
nr_i
>
{}]
c_blk_buf_
(
j
)
=
.
template
AsType
<
float
>()[
Number
<
blk_i
*
BlkSize
+
j
>
{}];
c_thread_buf
[
Number
<
(
mr_i
*
NRepeat
+
nr_i
)
*
NumXdlops
+
xdlops_i
>
{}]
.
template
AsType
<
float
>()[
Number
<
blk_i
*
BlkSize
+
j
>
{}];
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
mr_i
,
nr_i
,
xdlops_i
,
blk_i
);
const
index_t
k_thread_data_on_global
=
m_block_data_idx_on_global
+
c_thread_mtx_on_block
[
I0
];
const
index_t
b_thread_data_on_global
=
n_block_data_idx_on_global
+
c_thread_mtx_on_block
[
I1
];
constexpr
auto
c_m0_m1_m2_n_global_tensor_iterator_hacks
=
CGlobalIteratorHacks
{};
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_m0_m1_m2_n_thread_desc
),
decltype
(
c_m0_m1_m2_n_global_desc
),
Sequence
<
M0
,
1
,
M2
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
// CThreadTransferSrcDstAccessOrder,
3
,
// CThreadTransferSrcDstVectorDim,
1
,
// CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation
,
1
,
true
>
{
c_m0_m1_m2_n_global_desc
,
make_multi_index
(
k_thread_data_on_global
/
(
M2
*
M1
),
k_thread_data_on_global
%
(
M2
*
M1
)
/
M2
,
k_thread_data_on_global
%
M2
,
b_thread_data_on_global
)}
.
Run
(
c_m0_m1_m2_n_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_blk_buf_
,
c_m0_m1_m2_n_global_desc
,
c_global_buf
,
c_m0_m1_m2_n_global_tensor_iterator_hacks
);
});
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
mr_i
,
nr_i
,
blk_i
);
const
index_t
k_thread_data_on_global
=
m_block_data_idx_on_global
+
c_thread_mtx_on_block
[
I0
];
const
index_t
b_thread_data_on_global
=
n_block_data_idx_on_global
+
c_thread_mtx_on_block
[
I1
];
constexpr
auto
c_m0_m1_m2_n_global_tensor_iterator_hacks
=
CGlobalIteratorHacks
{};
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_m0_m1_m2_n_thread_desc
),
decltype
(
c_m0_m1_m2_n_global_desc
),
Sequence
<
M0
,
1
,
M2
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
// CThreadTransferSrcDstAccessOrder,
3
,
// CThreadTransferSrcDstVectorDim,
1
,
// CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation
,
1
,
true
>
{
c_m0_m1_m2_n_global_desc
,
make_multi_index
(
k_thread_data_on_global
/
(
M2
*
M1
),
k_thread_data_on_global
%
(
M2
*
M1
)
/
M2
,
k_thread_data_on_global
%
M2
,
b_thread_data_on_global
)}
.
Run
(
c_m0_m1_m2_n_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_blk_buf_
,
c_m0_m1_m2_n_global_desc
,
c_global_buf
,
c_m0_m1_m2_n_global_tensor_iterator_hacks
);
});
});
});
});
});
});
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
3399ddaf
...
@@ -50,10 +50,15 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
...
@@ -50,10 +50,15 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
k_base
=
1
;
static
constexpr
index_t
k_base
=
1
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
COffset
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
{
return
intrin_mfma_f32_32x32x1f32
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
return
intrin_mfma_f32_32x32x1f32
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
...
@@ -74,10 +79,15 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
...
@@ -74,10 +79,15 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
k_base
=
1
;
static
constexpr
index_t
k_base
=
1
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
COffset
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
{
return
intrin_mfma_f32_32x32x2f32
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
return
intrin_mfma_f32_32x32x2f32
<
MPerXdlops
,
NPerXdlops
,
COffset
>::
Run
(
a
,
b
,
reg_c
);
}
}
};
};
...
@@ -528,7 +538,7 @@ struct xdlops_info
...
@@ -528,7 +538,7 @@ struct xdlops_info
static
constexpr
index_t
MRepeats
=
MRepeats_
;
static
constexpr
index_t
MRepeats
=
MRepeats_
;
static
constexpr
index_t
NRepeats
=
NRepeats_
;
static
constexpr
index_t
NRepeats
=
NRepeats_
;
static
constexpr
bool
IsABroadcast
()
{
return
NPerXdlops
>=
MPerXdlops
;
}
//
static constexpr bool IsABroadcast() { return NPerXdlops >= MPerXdlops; }
static
constexpr
bool
IsKReduction
()
static
constexpr
bool
IsKReduction
()
{
{
...
@@ -743,9 +753,11 @@ struct XdlopsGemm
...
@@ -743,9 +753,11 @@ struct XdlopsGemm
using
CIndex
=
MultiIndex
<
2
>
;
using
CIndex
=
MultiIndex
<
2
>
;
__device__
static
constexpr
index_t
GetNumBlksPerXdlops
()
__device__
static
constexpr
index_t
GetNumBlks
()
{
return
mfma_type
.
num_output_blks
;
}
__device__
static
constexpr
index_t
GetNumXdlops
()
{
{
return
(
MPerXdlops
*
NPerXdlops
)
/
(
mfma_type
.
m
*
mfma_type
.
n
);
return
MPerXdlops
*
NPerXdlops
/
(
mfma_type
.
m
*
mfma_type
.
n
*
mfma_type
.
num_output_blks
);
}
}
__host__
__device__
constexpr
XdlopsGemm
()
__host__
__device__
constexpr
XdlopsGemm
()
...
@@ -791,42 +803,27 @@ struct XdlopsGemm
...
@@ -791,42 +803,27 @@ struct XdlopsGemm
static_assert
(
KPerWave
%
KPerXdlops
==
0
,
"KPerWave cannot be divided by KPerXdlops"
);
static_assert
(
KPerWave
%
KPerXdlops
==
0
,
"KPerWave cannot be divided by KPerXdlops"
);
constexpr
index_t
c_offset
=
CDesc
{}.
CalculateOffset
(
make_tuple
(
m0
,
n0
));
constexpr
index_t
c_offset
=
CDesc
{}.
CalculateOffset
(
make_tuple
(
m0
,
n0
))
*
GetNumXdlops
()
;
static_for
<
0
,
KPerWave
,
KPerXdlops
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPerWave
,
KPerXdlops
>
{}([
&
](
auto
k
)
{
constexpr
index_t
a_offset
=
ADesc
{}.
CalculateOffset
(
make_tuple
(
k
,
m0
,
0
));
constexpr
index_t
a_offset
=
ADesc
{}.
CalculateOffset
(
make_tuple
(
k
,
m0
,
0
));
constexpr
index_t
b_offset
=
BDesc
{}.
CalculateOffset
(
make_tuple
(
k
,
n0
,
0
));
constexpr
index_t
b_offset
=
BDesc
{}.
CalculateOffset
(
make_tuple
(
k
,
n0
,
0
));
mfma_type
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
Number
<
a_offset
>
{}],
mfma_type
.
template
run
<
MPerXdlops
,
NPerXdlops
,
c_offset
>(
p_b_wave
[
Number
<
b_offset
>
{}],
p_a_wave
[
Number
<
a_offset
>
{}],
p_b_wave
[
Number
<
b_offset
>
{}],
p_c_thread
);
p_c_thread
(
Number
<
c_offset
>
{}));
});
});
}
}
__device__
static
CIndex
GetBeginOfThreadBlk
(
index_t
i
)
__device__
static
CIndex
GetBeginOfThreadBlk
(
index_t
xdlops_i
,
index_t
blk_
i
)
{
{
const
index_t
xdlops_i
=
i
/
GetNumBlksPerXdlops
();
const
index_t
j
=
i
%
GetNumBlksPerXdlops
();
const
index_t
m_i
=
xdlops_i
/
NRepeats
;
const
index_t
n_i
=
xdlops_i
%
NRepeats
;
const
index_t
laneId
=
get_thread_local_1d_id
()
%
mfma_type
.
wave_size
;
const
index_t
laneId
=
get_thread_local_1d_id
()
%
mfma_type
.
wave_size
;
const
index_t
blk_id
=
laneId
/
mfma_type
.
num_threads_blk
;
const
index_t
blk_id
=
laneId
/
mfma_type
.
num_threads_blk
;
const
index_t
blk_td
=
laneId
%
mfma_type
.
num_threads_blk
;
const
index_t
blk_td
=
laneId
%
mfma_type
.
num_threads_blk
;
index_t
col_blk
=
j
%
mfma_type
.
num_output_blks
;
index_t
n_offset
=
blk_i
*
mfma_type
.
n
+
blk_td
;
index_t
row_blk
=
j
/
mfma_type
.
num_output_blks
;
index_t
m_offset
=
xdlops_i
*
mfma_type
.
m
+
blk_id
*
mfma_type
.
group_size
;
static_if
<!
IsABroadcast
>
{}([
&
](
auto
)
{
return
CIndex
{
m_offset
,
n_offset
};
col_blk
=
j
/
mfma_type
.
num_output_blks
;
row_blk
=
j
%
mfma_type
.
num_output_blks
;
});
index_t
col
=
col_blk
*
mfma_type
.
n
+
blk_td
+
n_i
*
NPerXdlops
;
index_t
row
=
row_blk
*
mfma_type
.
m
+
blk_id
*
mfma_type
.
group_size
+
m_i
*
MPerXdlops
;
return
CIndex
{
row
,
col
};
}
}
static
constexpr
index_t
MRepeats
=
GetXdlopsInfo
().
MRepeats
;
static
constexpr
index_t
MRepeats
=
GetXdlopsInfo
().
MRepeats
;
...
@@ -834,8 +831,8 @@ struct XdlopsGemm
...
@@ -834,8 +831,8 @@ struct XdlopsGemm
static
constexpr
index_t
MPerXdlops
=
GetXdlopsInfo
().
MPerXdlops
;
static
constexpr
index_t
MPerXdlops
=
GetXdlopsInfo
().
MPerXdlops
;
static
constexpr
index_t
NPerXdlops
=
GetXdlopsInfo
().
NPerXdlops
;
static
constexpr
index_t
NPerXdlops
=
GetXdlopsInfo
().
NPerXdlops
;
static
constexpr
bool
IsKReduction
=
GetXdlopsInfo
().
IsKReduction
();
static
constexpr
bool
IsKReduction
=
GetXdlopsInfo
().
IsKReduction
();
static
constexpr
bool
IsABroadcast
=
GetXdlopsInfo
().
IsABroadcast
();
//
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
static
constexpr
index_t
KPerXdlops
=
GetXdlopsInfo
().
GetKPerXdlops
();
static
constexpr
index_t
KPerXdlops
=
GetXdlopsInfo
().
GetKPerXdlops
();
static
constexpr
auto
GetBlkId
(
const
index_t
lane_id
)
static
constexpr
auto
GetBlkId
(
const
index_t
lane_id
)
...
@@ -850,7 +847,7 @@ struct XdlopsGemm
...
@@ -850,7 +847,7 @@ struct XdlopsGemm
static
constexpr
auto
mfma_type
=
GetXdlopsInfo
().
mfma_type
;
static
constexpr
auto
mfma_type
=
GetXdlopsInfo
().
mfma_type
;
struct
Output
Layout
struct
C
Layout
{
{
__host__
__device__
static
constexpr
index_t
M1
()
{
return
mfma_type
.
num_groups_blk
;
}
__host__
__device__
static
constexpr
index_t
M1
()
{
return
mfma_type
.
num_groups_blk
;
}
__host__
__device__
static
constexpr
index_t
M0
()
{
return
mfma_type
.
group_size
;
}
__host__
__device__
static
constexpr
index_t
M0
()
{
return
mfma_type
.
group_size
;
}
...
@@ -859,13 +856,16 @@ struct XdlopsGemm
...
@@ -859,13 +856,16 @@ struct XdlopsGemm
__device__
static
constexpr
index_t
GetBlkSize
()
{
return
mfma_type
.
num_regs_blk
;
}
__device__
static
constexpr
index_t
GetBlkSize
()
{
return
mfma_type
.
num_regs_blk
;
}
__device__
static
constexpr
index_t
GetNumBlks
()
__device__
static
constexpr
index_t
GetNumBlks
()
{
return
mfma_type
.
num_output_blks
;
}
__device__
static
constexpr
index_t
GetNumXdlops
()
{
{
return
GetNumBlksPerXdlops
()
*
MRepeats
*
NRepeats
;
return
MPerXdlops
*
NPerXdlops
/
(
mfma_type
.
m
*
mfma_type
.
n
*
mfma_type
.
num_output_blks
);
}
}
};
};
__host__
__device__
static
constexpr
auto
Get
Output
Layout
()
{
return
Output
Layout
{};
}
__host__
__device__
static
constexpr
auto
Get
C
Layout
()
{
return
C
Layout
{};
}
};
};
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/utility/amd_xdlops.hpp
View file @
3399ddaf
...
@@ -198,7 +198,7 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
...
@@ -198,7 +198,7 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_4x4x2bf16
(
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_4x4x2bf16
(
ushort2_t
,
ushort2_t
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.4x4x2bf16"
);
ushort2_t
,
ushort2_t
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.4x4x2bf16"
);
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
struct
intrin_mfma_f32_32x32x1f32
;
struct
intrin_mfma_f32_32x32x1f32
;
// template <index_t AStride, index_t BStride>
// template <index_t AStride, index_t BStride>
...
@@ -237,16 +237,28 @@ struct intrin_mfma_f32_32x32x1f32;
...
@@ -237,16 +237,28 @@ struct intrin_mfma_f32_32x32x1f32;
//}
//}
//};
//};
template
<
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_32x32x1f32
<
64
,
64
>
struct
intrin_mfma_f32_32x32x1f32
<
64
,
64
,
COffset
>
{
{
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
static
void
Run
(
const
FloatA
&
reg_a
,
const
FloatB
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
FloatA
&
reg_a
,
const
FloatB
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
1
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
,
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()[
Number
<
1
>
{}],
1
,
1
,
0
);
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
reg_c
(
Number
<
COffset
+
1
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
+
1
>
{}].
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
1
,
0
);
}
}
};
};
...
@@ -272,17 +284,23 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
...
@@ -272,17 +284,23 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
//}
//}
//};
//};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
struct
intrin_mfma_f32_32x32x2f32
;
struct
intrin_mfma_f32_32x32x2f32
;
template
<
>
template
<
index_t
COffset
>
struct
intrin_mfma_f32_32x32x2f32
<
32
,
32
>
struct
intrin_mfma_f32_32x32x2f32
<
32
,
32
,
COffset
>
{
{
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
static
void
Run
(
const
FloatA
&
reg_a
,
const
FloatB
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
FloatA
&
reg_a
,
const
FloatB
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x2f32
(
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
llvm_intrin_amdgcn_mfma_f32_32x32x2f32
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
3399ddaf
...
@@ -104,25 +104,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -104,25 +104,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
#else
#else
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmMPerBlock
=
256
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmKPerWave
=
4
;
constexpr
index_t
GemmKPerWave
=
4
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
1
;
constexpr
index_t
NRepeat
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
4
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
64
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
64
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
4
>
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
8
,
32
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
4
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
4
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment