Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
873d0958
"docs/vscode:/vscode.git/clone" did not exist on "637628a70fc708057cfd6dfe8717ca9035553bc8"
Commit
873d0958
authored
Apr 21, 2022
by
ltqin
Browse files
fix M N PerXdlops
parent
10a2ae2f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
12 deletions
+24
-12
example/01_gemm/gemm_xdl_fp64.cpp
example/01_gemm/gemm_xdl_fp64.cpp
+8
-1
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+10
-10
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+6
-1
No files found.
example/01_gemm/gemm_xdl_fp64.cpp
View file @
873d0958
...
...
@@ -50,7 +50,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
F64
,
F64
,
F64
,
F64
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
128
,
4
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
true
,
7
,
1
>
;
<
F64
,
F64
,
F64
,
F64
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
4
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
true
,
7
,
1
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
@@ -198,6 +198,13 @@ int main(int argc, char* argv[])
ref_invoker
.
Run
(
ref_argument
);
if
(
0
)
{
LogRangeAsType
<
double
>
(
std
::
cout
<<
"a : "
,
a_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
double
>
(
std
::
cout
<<
"b: "
,
b_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
double
>
(
std
::
cout
<<
"c_device: "
,
c_m_n_device_result
.
mData
,
","
)
<<
std
::
endl
;
}
ck
::
utils
::
check_err
(
c_m_n_device_result
.
mData
,
c_m_n_host_result
.
mData
);
}
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
873d0958
...
...
@@ -387,17 +387,17 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f64_16x16x4f64
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
// group_size * num_groups_per_blk;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
// group_size * num_groups_per_blk;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
4
;
// wave_size / num_threads_per_blk;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
1
;
static
constexpr
bool
is_k_reduction
=
true
;
static
constexpr
index_t
num_input_blks
=
4
;
// wave_size / num_threads_per_blk;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
1
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
...
...
@@ -413,7 +413,7 @@ struct MfmaSelector
static
constexpr
auto
GetMfma
();
template
<
>
static
constexpr
auto
GetMfma
<
double
,
32
,
32
>
()
static
constexpr
auto
GetMfma
<
double
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f64_16x16x4f64
;
}
...
...
include/ck/utility/amd_xdlops.hpp
View file @
873d0958
...
...
@@ -298,13 +298,18 @@ template <index_t MPerWave, index_t NPerWave>
struct
intrin_mfma_f64_16x16x4f64
;
template
<
>
struct
intrin_mfma_f64_16x16x4f64
<
32
,
32
>
struct
intrin_mfma_f64_16x16x4f64
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
double
&
reg_a
,
const
double
&
reg_b
,
FloatC
&
reg_c
)
{
#ifdef __gxf90a__
reg_c
.
template
AsType
<
double4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f64_16x16x4f64
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
double4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
#else
reg_c
.
template
AsType
<
double4_t
>()(
Number
<
0
>
{})
=
{
reg_a
,
reg_a
,
reg_b
,
reg_b
};
#endif
}
};
}
// namespace ck
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment