Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bf445c31
"...composable_kernel_rocm.git" did not exist on "5d37d7bff4e631c3b94112c31a52f209ca39dfe2"
Commit
bf445c31
authored
Aug 31, 2023
by
Bartlomiej Wroblewski
Browse files
Review: Change names from FloatX to XDataType
parent
0ff1d1f8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
72 deletions
+74
-72
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp
+16
-15
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp
+25
-25
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
+33
-32
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp
View file @
bf445c31
...
@@ -20,8 +20,8 @@ namespace ck {
...
@@ -20,8 +20,8 @@ namespace ck {
* `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves.
* `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves.
*/
*/
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
ABDataType
,
typename
FloatAcc
,
typename
AccDataType
,
typename
AK0MK1BlockDesc
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
typename
BK0NK1BlockDesc
,
index_t
MPerDpp
,
index_t
MPerDpp
,
...
@@ -50,7 +50,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
...
@@ -50,7 +50,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
dpp_gemm
=
DppGemm
<
FloatAB
,
MPerDpp
,
NPerDpp
,
KPack
>
{};
static
constexpr
auto
dpp_gemm
=
DppGemm
<
ABDataType
,
MPerDpp
,
NPerDpp
,
KPack
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
dpp_gemm
.
K0PerDpp
;
static
constexpr
index_t
KPerThread
=
KPerBlock
/
dpp_gemm
.
K0PerDpp
;
...
@@ -58,7 +58,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
...
@@ -58,7 +58,7 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerDpp
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerDpp
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
AccDataType
,
MRepeat
*
NRepeat
,
MRepeat
*
NRepeat
,
dpp_gemm
.
GetRegSizePerDpp
(),
dpp_gemm
.
GetRegSizePerDpp
(),
true
>
true
>
...
@@ -260,9 +260,9 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
...
@@ -260,9 +260,9 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
const
BBlockBuffer
&
b_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
CThreadBuffer
&
c_thread_buf
)
const
{
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ABDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ABDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
...
@@ -284,17 +284,18 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
...
@@ -284,17 +284,18 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
b_thread_buf
);
b_thread_buf
);
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
ABDataType
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
vector_type
<
ABDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
a_thread_vec
.
template
AsType
<
ABDataType
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
b_thread_vec
.
template
AsType
<
ABDataType
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
});
});
using
dpp_input_type
=
typename
vector_type
<
FloatAB
,
dpp_gemm
.
K1PerDpp
>::
type
;
using
dpp_input_type
=
typename
vector_type
<
ABDataType
,
dpp_gemm
.
K1PerDpp
>::
type
;
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
...
@@ -320,8 +321,8 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
...
@@ -320,8 +321,8 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
dpp_gemm
.
GetRegSizePerDpp
()));
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
dpp_gemm
.
GetRegSizePerDpp
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ABDataType
,
FloatAB
,
ABDataType
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
...
@@ -330,8 +331,8 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
...
@@ -330,8 +331,8 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
A_K1
,
A_K1
,
A_K1
>
;
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ABDataType
,
FloatAB
,
ABDataType
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp
View file @
bf445c31
...
@@ -51,9 +51,9 @@ __global__ void
...
@@ -51,9 +51,9 @@ __global__ void
}
}
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
ABDataType
,
typename
FloatAcc
,
typename
AccDataType
,
typename
FloatC
,
typename
CDataType
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
...
@@ -172,9 +172,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
...
@@ -172,9 +172,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
// Argument
// Argument
struct
Argument
:
public
Problem
,
public
tensor_operation
::
device
::
BaseArgument
struct
Argument
:
public
Problem
,
public
tensor_operation
::
device
::
BaseArgument
{
{
__host__
Argument
(
const
FloatAB
*
p_a_grid_
,
__host__
Argument
(
const
ABDataType
*
p_a_grid_
,
const
FloatAB
*
p_b_grid_
,
const
ABDataType
*
p_b_grid_
,
FloatC
*
p_c_grid_
,
CDataType
*
p_c_grid_
,
index_t
M_
,
index_t
M_
,
index_t
N_
,
index_t
N_
,
index_t
K_
,
index_t
K_
,
...
@@ -188,9 +188,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
...
@@ -188,9 +188,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
{
{
}
}
const
FloatAB
*
p_a_grid
;
const
ABDataType
*
p_a_grid
;
const
FloatAB
*
p_b_grid
;
const
ABDataType
*
p_b_grid
;
FloatC
*
p_c_grid
;
CDataType
*
p_c_grid
;
};
};
using
GridwiseGemmPipe
=
remove_cvref_t
<
using
GridwiseGemmPipe
=
remove_cvref_t
<
...
@@ -252,7 +252,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
...
@@ -252,7 +252,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
constexpr
auto
b_block_space_size_aligned
=
constexpr
auto
b_block_space_size_aligned
=
math
::
integer_least_multiple
(
b_block_desc_k0_n_k1
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
b_block_desc_k0_n_k1
.
GetElementSpaceSize
(),
max_lds_align
);
return
(
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
);
return
(
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
ABDataType
);
}
}
__host__
static
constexpr
bool
CheckValidity
(
const
Problem
&
problem
)
__host__
static
constexpr
bool
CheckValidity
(
const
Problem
&
problem
)
...
@@ -347,8 +347,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
...
@@ -347,8 +347,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
using
BlockwiseGemm
=
using
BlockwiseGemm
=
BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
<
BlockSize
,
BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
<
BlockSize
,
FloatAB
,
ABDataType
,
FloatAcc
,
AccDataType
,
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
b_block_desc_k0_n_k1
),
decltype
(
b_block_desc_k0_n_k1
),
MPerDpp
,
MPerDpp
,
...
@@ -430,9 +430,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
...
@@ -430,9 +430,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
typename
AGridDesc_K0_M_K1
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
>
typename
CGridDesc_M_N
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
ABDataType
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
CDataType
*
__restrict__
p_c_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
...
@@ -488,8 +488,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
...
@@ -488,8 +488,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
Sequence
<
K0PerBlock
,
MPerBlock
,
K1
>
,
Sequence
<
K0PerBlock
,
MPerBlock
,
K1
>
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
ABDataType
,
FloatAB
,
ABDataType
,
decltype
(
a_grid_desc_k0_m_k1
),
decltype
(
a_grid_desc_k0_m_k1
),
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
a_block_desc_k0_m_k1
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
...
@@ -518,8 +518,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
...
@@ -518,8 +518,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
Sequence
<
K0PerBlock
,
NPerBlock
,
K1
>
,
Sequence
<
K0PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
ABDataType
,
FloatAB
,
ABDataType
,
decltype
(
b_grid_desc_k0_n_k1
),
decltype
(
b_grid_desc_k0_n_k1
),
decltype
(
b_block_desc_k0_n_k1
),
decltype
(
b_block_desc_k0_n_k1
),
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
...
@@ -548,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
...
@@ -548,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
// register
// register
auto
blockwise_gemm
=
auto
blockwise_gemm
=
BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
<
BlockSize
,
BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
<
BlockSize
,
FloatAB
,
ABDataType
,
FloatAcc
,
AccDataType
,
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
b_block_desc_k0_n_k1
),
decltype
(
b_block_desc_k0_n_k1
),
MPerDpp
,
MPerDpp
,
...
@@ -565,10 +565,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
...
@@ -565,10 +565,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
a_block_desc_k0_m_k1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
),
a_block_desc_k0_m_k1
.
GetElementSpaceSize
());
static_cast
<
ABDataType
*>
(
p_shared
),
a_block_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_space_size_aligned
,
static_cast
<
ABDataType
*>
(
p_shared
)
+
a_block_space_size_aligned
,
b_block_desc_k0_n_k1
.
GetElementSpaceSize
());
b_block_desc_k0_n_k1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
...
@@ -642,8 +642,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
...
@@ -642,8 +642,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_dpp
make_multi_index
(
n_thread_data_on_grid
));
make_multi_index
(
n_thread_data_on_grid
));
auto
c_thread_copy
=
auto
c_thread_copy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
FloatC
,
CDataType
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_n2
),
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_n2
),
decltype
(
c_grid_desc_m0_n0_m1_n1_m2_n2
),
decltype
(
c_grid_desc_m0_n0_m1_n1_m2_n2
),
CElementwiseOperation
,
CElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
View file @
bf445c31
...
@@ -54,18 +54,18 @@ struct dpp_type<DppInstr::dpp8_f16_32x8x2>
...
@@ -54,18 +54,18 @@ struct dpp_type<DppInstr::dpp8_f16_32x8x2>
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
bool
share_a
=
true
;
static
constexpr
bool
share_a
=
true
;
using
b
ase
_t
ype
=
half_t
;
using
B
ase
T
ype
=
half_t
;
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
__device__
void
run
(
const
ADataType
&
a
,
const
BDataType
&
b
,
CDataType
&
reg_c
)
const
{
{
dpp8
::
DppInstrRunner
<
m_per_thread
,
dpp8
::
DppInstrRunner
<
m_per_thread
,
n_per_thread
,
n_per_thread
,
k_per_dpp
,
k_per_dpp
,
b
ase
_t
ype
,
B
ase
T
ype
,
FloatA
,
ADataType
,
FloatB
,
BDataType
,
FloatC
,
CDataType
,
share_a
>
{}
share_a
>
{}
.
Run
(
a
,
b
,
reg_c
);
.
Run
(
a
,
b
,
reg_c
);
}
}
...
@@ -84,18 +84,18 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2>
...
@@ -84,18 +84,18 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2>
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
bool
share_a
=
true
;
static
constexpr
bool
share_a
=
true
;
using
b
ase
_t
ype
=
half_t
;
using
B
ase
T
ype
=
half_t
;
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
__device__
void
run
(
const
ADataType
&
a
,
const
BDataType
&
b
,
CDataType
&
reg_c
)
const
{
{
dpp8
::
DppInstrRunner
<
m_per_thread
,
dpp8
::
DppInstrRunner
<
m_per_thread
,
n_per_thread
,
n_per_thread
,
k_per_dpp
,
k_per_dpp
,
b
ase
_t
ype
,
B
ase
T
ype
,
FloatA
,
ADataType
,
FloatB
,
BDataType
,
FloatC
,
CDataType
,
share_a
>
{}
share_a
>
{}
.
Run
(
a
,
b
,
reg_c
);
.
Run
(
a
,
b
,
reg_c
);
}
}
...
@@ -114,27 +114,27 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2>
...
@@ -114,27 +114,27 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2>
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
n_per_thread
=
1
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
index_t
k_per_dpp
=
2
;
static
constexpr
bool
share_a
=
true
;
static
constexpr
bool
share_a
=
true
;
using
b
ase
_t
ype
=
half_t
;
using
B
ase
T
ype
=
half_t
;
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
index_t
MPerDpp
,
index_t
NPerDpp
,
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
__device__
void
run
(
const
ADataType
&
a
,
const
BDataType
&
b
,
CDataType
&
reg_c
)
const
{
{
dpp8
::
DppInstrRunner
<
m_per_thread
,
dpp8
::
DppInstrRunner
<
m_per_thread
,
n_per_thread
,
n_per_thread
,
k_per_dpp
,
k_per_dpp
,
b
ase
_t
ype
,
B
ase
T
ype
,
FloatA
,
ADataType
,
FloatB
,
BDataType
,
FloatC
,
CDataType
,
share_a
>
{}
share_a
>
{}
.
Run
(
a
,
b
,
reg_c
);
.
Run
(
a
,
b
,
reg_c
);
}
}
};
};
template
<
typename
b
ase
_t
ype
,
index_t
MPerDpp
,
index_t
NPerDpp
>
template
<
typename
B
ase
T
ype
,
index_t
MPerDpp
,
index_t
NPerDpp
>
struct
DppSelector
struct
DppSelector
{
{
template
<
typename
b
ase
_t
ype_
,
index_t
MPerDpp_
,
index_t
NPerDpp_
>
template
<
typename
B
ase
T
ype_
,
index_t
MPerDpp_
,
index_t
NPerDpp_
>
static
constexpr
auto
GetDpp
();
static
constexpr
auto
GetDpp
();
template
<
>
template
<
>
...
@@ -155,7 +155,7 @@ struct DppSelector
...
@@ -155,7 +155,7 @@ struct DppSelector
return
DppInstr
::
dpp8_f16_32x8x2
;
return
DppInstr
::
dpp8_f16_32x8x2
;
}
}
static
constexpr
auto
selected_dpp
=
dpp_type
<
GetDpp
<
b
ase
_t
ype
,
MPerDpp
,
NPerDpp
>
()
>
{};
static
constexpr
auto
selected_dpp
=
dpp_type
<
GetDpp
<
B
ase
T
ype
,
MPerDpp
,
NPerDpp
>
()
>
{};
__host__
__device__
constexpr
DppSelector
()
__host__
__device__
constexpr
DppSelector
()
{
{
...
@@ -200,7 +200,7 @@ struct DppSelector
...
@@ -200,7 +200,7 @@ struct DppSelector
static
constexpr
index_t
GetK1PerDpp
()
{
return
selected_dpp
.
k_per_dpp
;
}
static
constexpr
index_t
GetK1PerDpp
()
{
return
selected_dpp
.
k_per_dpp
;
}
};
};
template
<
typename
b
ase
_t
ype
,
index_t
MPerDpp
,
index_t
NPerDpp
,
index_t
KPack
>
template
<
typename
B
ase
T
ype
,
index_t
MPerDpp
,
index_t
NPerDpp
,
index_t
KPack
>
struct
DppGemm
struct
DppGemm
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -228,13 +228,14 @@ struct DppGemm
...
@@ -228,13 +228,14 @@ struct DppGemm
return
MPerDpp
*
NPerDpp
/
dpp_instr
.
wave_size
;
return
MPerDpp
*
NPerDpp
/
dpp_instr
.
wave_size
;
}
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
ADataType
,
class
BDataType
,
class
CDataType
>
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
__device__
void
Run
(
const
ADataType
&
p_a_wave
,
const
BDataType
&
p_b_wave
,
CDataType
&
p_c_thread
)
const
{
{
static_assert
(
is_same
<
b
ase
_t
ype
,
double
>::
value
||
is_same
<
b
ase
_t
ype
,
float
>::
value
||
static_assert
(
is_same
<
B
ase
T
ype
,
double
>::
value
||
is_same
<
B
ase
T
ype
,
float
>::
value
||
is_same
<
b
ase
_t
ype
,
half_t
>::
value
||
is_same
<
b
ase
_t
ype
,
bhalf_t
>::
value
||
is_same
<
B
ase
T
ype
,
half_t
>::
value
||
is_same
<
B
ase
T
ype
,
bhalf_t
>::
value
||
is_same
<
b
ase
_t
ype
,
int8_t
>::
value
||
is_same
<
b
ase
_t
ype
,
f8_t
>::
value
,
is_same
<
B
ase
T
ype
,
int8_t
>::
value
||
is_same
<
B
ase
T
ype
,
f8_t
>::
value
,
"base
b
ase
_t
ype must be double, float, half, bfloat16, and int8_t!"
);
"base
B
ase
T
ype must be double, float, half, bfloat16, and int8_t!"
);
static_for
<
0
,
KPack
/
dpp_instr
.
k_per_dpp
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPack
/
dpp_instr
.
k_per_dpp
,
1
>
{}([
&
](
auto
k
)
{
dpp_instr
.
template
run
<
MPerDpp
,
NPerDpp
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
dpp_instr
.
template
run
<
MPerDpp
,
NPerDpp
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
...
@@ -305,7 +306,7 @@ struct DppGemm
...
@@ -305,7 +306,7 @@ struct DppGemm
return
CIndex
{
m_offset
,
n_offset
};
return
CIndex
{
m_offset
,
n_offset
};
}
}
static
constexpr
auto
dpp
=
DppSelector
<
b
ase
_t
ype
,
MPerDpp
,
NPerDpp
>
{};
static
constexpr
auto
dpp
=
DppSelector
<
B
ase
T
ype
,
MPerDpp
,
NPerDpp
>
{};
static
constexpr
auto
dpp_instr
=
dpp
.
selected_dpp
;
static
constexpr
auto
dpp_instr
=
dpp
.
selected_dpp
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment