Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
01192e26
Commit
01192e26
authored
Jan 07, 2022
by
Jing Zhang
Browse files
test mfma builtins
parent
acbd7bd7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
56 deletions
+51
-56
composable_kernel/include/utility/amd_xdlops.hpp
composable_kernel/include/utility/amd_xdlops.hpp
+9
-26
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
+3
-3
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
+23
-11
script/profile_gemm.sh
script/profile_gemm.sh
+16
-16
No files found.
composable_kernel/include/utility/amd_xdlops.hpp
View file @
01192e26
...
...
@@ -6,22 +6,6 @@
namespace
ck
{
// A, B, C, cbsz, abid, blgp
// fp32
extern
"C"
__device__
float32_t
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
float
,
float
,
float32_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x1f32"
);
extern
"C"
__device__
float16_t
llvm_intrin_amdgcn_mfma_f32_32x32x2f32
(
float
,
float
,
float16_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x2f32"
);
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_16x16x4f32
(
float
,
float
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.16x16x4f32"
);
extern
"C"
__device__
float16_t
llvm_intrin_amdgcn_mfma_f32_16x16x1f32
(
float
,
float
,
float16_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.16x16x1f32"
);
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
float
,
float
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.4x4x1f32"
);
// fp16
extern
"C"
__device__
float32_t
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
half4_t
,
half4_t
,
float32_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.32x32x4f16"
);
...
...
@@ -86,9 +70,9 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intr
in_amdgcn_mfma_f32_32x32x1f32
(
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
__built
in_amdgcn_mfma_f32_32x32x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
1
>
{})
=
llvm_intr
in_amdgcn_mfma_f32_32x32x1f32
(
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
1
>
{})
=
__built
in_amdgcn_mfma_f32_32x32x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()[
Number
<
1
>
{}],
1
,
1
,
0
);
}
};
...
...
@@ -99,7 +83,7 @@ struct intrin_mfma_f32_32x32x1f32<32, 64>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intr
in_amdgcn_mfma_f32_32x32x1f32
(
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
__built
in_amdgcn_mfma_f32_32x32x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
}
};
...
...
@@ -113,7 +97,7 @@ struct intrin_mfma_f32_32x32x2f32<32, 32>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intr
in_amdgcn_mfma_f32_32x32x2f32
(
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__built
in_amdgcn_mfma_f32_32x32x2f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
};
...
...
@@ -127,7 +111,7 @@ struct intrin_mfma_f32_16x16x4f32<16, 16>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intr
in_amdgcn_mfma_f32_16x16x4f32
(
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__built
in_amdgcn_mfma_f32_16x16x4f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
};
...
...
@@ -141,8 +125,7 @@ struct intrin_mfma_f32_16x16x1f32<16, 64>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_16x16x1f32
(
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_16x16x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
2
,
0
,
0
);
}
};
...
...
@@ -156,7 +139,7 @@ struct intrin_mfma_f32_4x4x1f32<4, 64>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intr
in_amdgcn_mfma_f32_4x4x1f32
(
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__built
in_amdgcn_mfma_f32_4x4x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
}
};
...
...
@@ -167,9 +150,9 @@ struct intrin_mfma_f32_4x4x1f32<8, 64>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intr
in_amdgcn_mfma_f32_4x4x1f32
(
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__built
in_amdgcn_mfma_f32_4x4x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
1
>
{})
=
llvm_intr
in_amdgcn_mfma_f32_4x4x1f32
(
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
1
>
{})
=
__built
in_amdgcn_mfma_f32_4x4x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
1
>
{}],
4
,
1
,
0
);
}
};
...
...
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
View file @
01192e26
...
...
@@ -332,9 +332,9 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#endif
const
auto
K
=
a_m_k
.
mDesc
.
GetLengths
()[
1
];
const
auto
M
=
a_m_k
.
mDesc
.
GetLengths
()[
0
];
const
auto
N
=
b_k_n
.
mDesc
.
GetLengths
()[
1
];
const
index_t
K
=
a_m_k
.
mDesc
.
GetLengths
()[
1
];
const
index_t
M
=
a_m_k
.
mDesc
.
GetLengths
()[
0
];
const
index_t
N
=
b_k_n
.
mDesc
.
GetLengths
()[
1
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
View file @
01192e26
...
...
@@ -5,6 +5,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "element_wise_operation.hpp"
template
<
ck
::
index_t
BlockSize
,
typename
FloatAB
,
...
...
@@ -70,6 +71,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
using
ElementwiseOperation
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
FloatAB
,
...
...
@@ -79,6 +82,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K
,
CMNGridDesc
,
ElementwiseOperation
,
ElementwiseOperation
,
ElementwiseOperation
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
...
...
@@ -87,7 +93,6 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
...
...
@@ -95,7 +100,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
AThreadTransferSrcResetCoordinateAfterRun
,
B
Block
TransferThreadSliceLengths_K0_N_K1
,
A
Block
LdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
...
...
@@ -103,17 +108,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BBlockLdsAddExtraN
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
AGridStepHacks
,
BGridStepHacks
,
CGridStepHacks
,
AGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
,
CAccessOrderMRepeatNRepeat
,
ABlockLdsAddExtraM
,
BBlockLdsAddExtraN
>
;
CThreadTransferDstScalarPerVector
>
;
{
std
::
cout
<<
"a_grid_desc_k0_m_k1{"
<<
a_grid_desc_k0_m_k1
.
GetLength
(
I0
)
<<
", "
...
...
@@ -152,6 +150,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
float
ave_time
=
0
;
auto
element_op_
=
ElementwiseOperation
{};
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if
(
has_main_k0_block_loop
)
{
...
...
@@ -162,6 +162,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
remove_reference_t
<
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
BGridDesc_K0_N_K
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
ElementwiseOperation
,
ElementwiseOperation
,
ElementwiseOperation
,
remove_reference_t
<
Block2CTileMap
>
,
true
>
;
...
...
@@ -176,6 +179,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
element_op_
,
element_op_
,
element_op_
,
block_2_ctile_map
);
}
else
...
...
@@ -187,6 +193,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
remove_reference_t
<
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
BGridDesc_K0_N_K
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
ElementwiseOperation
,
ElementwiseOperation
,
ElementwiseOperation
,
remove_reference_t
<
Block2CTileMap
>
,
false
>
;
...
...
@@ -201,6 +210,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
element_op_
,
element_op_
,
element_op_
,
block_2_ctile_map
);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
...
...
script/profile_gemm.sh
View file @
01192e26
...
...
@@ -24,22 +24,22 @@ REPEAT=$7
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
960 1024 1024
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1920 2048 2048
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
3840 4096 4096
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
7680 8192 8192
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1024 1024 1024 1024 1024 1024
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
2048 2048 2048 2048 2048 2048
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
4096 4096 4096 4096 4096 4096
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
8192 8192 8192 8192 8192 8192
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1024 1024 1024 1056 1056 1056
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
2048 2048 2048 2080 2080 2080
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
4096 4096 4096 4128 4128 4128
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
8192 8192 8192 8224 8224 8224
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1024 1024 1024 1088 1088 1088
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
2048 2048 2048 2112 2112 2112
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
4096 4096 4096 4160 4160 4160
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
8192 8192 8192 8256 8256 8256
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment