Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
97b32147
"local_mode/README.md" did not exist on "18493eefc09cd47a6d47da3af0d73cbee063de9f"
Commit
97b32147
authored
Jan 29, 2025
by
Andriy Roshchenko
Browse files
Fix strides
parent
8e89c4d9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
10 deletions
+31
-10
test/mx_mfma_op/mx_mfma_op.hpp
test/mx_mfma_op/mx_mfma_op.hpp
+31
-10
No files found.
test/mx_mfma_op/mx_mfma_op.hpp
View file @
97b32147
...
@@ -358,15 +358,15 @@ struct GemmParams
...
@@ -358,15 +358,15 @@ struct GemmParams
*
*
* A[16x128] * B[128x16] = C[16x16], all row major.
* A[16x128] * B[128x16] = C[16x16], all row major.
*/
*/
GemmParams
()
:
M
(
16
),
N
(
16
),
K
(
128
)
,
StrideA
(
128
),
StrideB
(
16
),
StrideC
(
16
)
{}
GemmParams
()
:
M
(
16
),
N
(
16
),
K
(
128
)
{}
ck
::
index_t
M
;
ck
::
index_t
M
;
ck
::
index_t
N
;
ck
::
index_t
N
;
ck
::
index_t
K
;
ck
::
index_t
K
;
ck
::
index_t
StrideA
;
ck
::
index_t
StrideA
=
-
1
;
ck
::
index_t
StrideB
;
ck
::
index_t
StrideB
=
-
1
;
ck
::
index_t
StrideC
;
ck
::
index_t
StrideC
=
-
1
;
};
};
template
<
typename
GemmInstance
,
template
<
typename
GemmInstance
,
...
@@ -465,9 +465,30 @@ struct TestMFMA
...
@@ -465,9 +465,30 @@ struct TestMFMA
params
.
M
=
BLOCK_M
;
params
.
M
=
BLOCK_M
;
params
.
N
=
BLOCK_N
;
params
.
N
=
BLOCK_N
;
params
.
K
=
BLOCK_K
;
params
.
K
=
BLOCK_K
;
params
.
StrideA
=
BLOCK_K
;
// M K
params
.
StrideB
=
BLOCK_N
;
// K N
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
params
.
StrideC
=
BLOCK_N
;
// M N
std
::
size_t
col
,
ck
::
index_t
stride
,
auto
layout
)
{
if
(
stride
==
-
1
)
{
// give a chance if stride is -1, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
static_cast
<
std
::
size_t
>
(
col
);
}
else
{
return
static_cast
<
std
::
size_t
>
(
row
);
}
}
else
return
static_cast
<
std
::
size_t
>
(
stride
);
};
params
.
StrideA
=
f_get_default_stride
(
BLOCK_M
,
BLOCK_K
,
params
.
StrideA
,
ALayout
{});
params
.
StrideB
=
f_get_default_stride
(
BLOCK_K
,
BLOCK_N
,
params
.
StrideB
,
BLayout
{});
params
.
StrideC
=
f_get_default_stride
(
BLOCK_M
,
BLOCK_N
,
params
.
StrideC
,
CLayout
{});
auto
host_tensors
=
PrepareGemmTensors
(
params
);
auto
host_tensors
=
PrepareGemmTensors
(
params
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment