Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2bd601e1
Commit
2bd601e1
authored
Jan 30, 2025
by
Andriy Roshchenko
Browse files
Cleanup
parent
718c7abb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
93 deletions
+8
-93
test/mx_mfma_op/mx_mfma_op.hpp
test/mx_mfma_op/mx_mfma_op.hpp
+8
-93
No files found.
test/mx_mfma_op/mx_mfma_op.hpp
View file @
2bd601e1
...
...
@@ -30,14 +30,8 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
{
/**
 * @brief Dispatches one MFMA operation for the 16x16 specialization.
 *
 * Builds an mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4> object and forwards
 * the three fragments to its run<16, 16, AFragT, BFragT, AccumFragT>() member.
 *
 * @param fragA   A-operand fragment; taken by const reference.
 * @param fragB   B-operand fragment; taken by const reference.
 * @param fragAcc Accumulator fragment; taken by non-const reference and handed
 *                to run() (presumably updated in place - confirm against
 *                mfma_type::run).
 */
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
{
    auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
    op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
}
};
...
...
@@ -46,14 +40,8 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
{
/**
 * @brief Dispatches one MFMA operation for the 32x32 specialization.
 *
 * Builds an mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4> object and forwards
 * the three fragments to its run<32, 32, AFragT, BFragT, AccumFragT>() member.
 *
 * @param fragA   A-operand fragment; taken by const reference.
 * @param fragB   B-operand fragment; taken by const reference.
 * @param fragAcc Accumulator fragment; taken by non-const reference and handed
 *                to run() (presumably updated in place - confirm against
 *                mfma_type::run).
 */
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
{
    auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
    op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
}
};
...
...
@@ -131,43 +119,9 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
// BLOCK_M is a stride in A matrix
auto
startOffset
=
col_major
(
startCoord2D
,
BLOCK_M
);
auto
kOffset
=
col_major
(
stepCoord2D
,
BLOCK_M
);
// kOffset == BLOCK_M
// This means every BLOCK_M element is loaded into output vector
#if 0
auto fragA = AScalarFragT{
bit_cast<ARawT>(input_ptr[startOffset]), // XXX v[0] = Reg 0 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 1 * kOffset]), // XXX v[1] = Reg 0 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 2 * kOffset]), // XXX v[2] = Reg 0 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 3 * kOffset]), // XXX v[3] = Reg 0 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 4 * kOffset]), // XXX v[4] = Reg 1 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 5 * kOffset]), // XXX v[5] = Reg 1 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 6 * kOffset]), // XXX v[6] = Reg 1 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 7 * kOffset]), // XXX v[7] = Reg 1 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 8 * kOffset]), // XXX v[8] = Reg 2 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 9 * kOffset]), // XXX v[9] = Reg 2 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 10 * kOffset]), // XXX v[10] = Reg 2 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 11 * kOffset]), // XXX v[11] = Reg 2 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 12 * kOffset]), // XXX v[12] = Reg 3 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 13 * kOffset]), // XXX v[13] = Reg 3 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 14 * kOffset]), // XXX v[14] = Reg 3 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 15 * kOffset]), // XXX v[15] = Reg 3 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 16 * kOffset]), // XXX v[16] = Reg 4 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 17 * kOffset]), // XXX v[17] = Reg 4 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 18 * kOffset]), // XXX v[18] = Reg 4 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 19 * kOffset]), // XXX v[19] = Reg 4 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 20 * kOffset]), // XXX v[20] = Reg 5 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 21 * kOffset]), // XXX v[21] = Reg 5 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 22 * kOffset]), // XXX v[22] = Reg 5 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 23 * kOffset]), // XXX v[23] = Reg 5 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 24 * kOffset]), // XXX v[24] = Reg 6 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 25 * kOffset]), // XXX v[25] = Reg 6 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 26 * kOffset]), // XXX v[26] = Reg 6 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 27 * kOffset]), // XXX v[27] = Reg 6 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 28 * kOffset]), // XXX v[28] = Reg 7 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 29 * kOffset]), // XXX v[29] = Reg 7 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 30 * kOffset]), // XXX v[30] = Reg 7 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 31 * kOffset])}; // XXX v[31] = Reg 7 [24:31]
#else
// kOffset == BLOCK_M
// This means every BLOCK_M element is loaded into output vector
auto
fragA
=
AScalarFragT
{};
#pragma unroll VW
for
(
uint32_t
i
=
0
;
i
<
VW
;
i
++
)
...
...
@@ -175,7 +129,6 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
fragA
[
i
]
=
bit_cast
<
ARawT
>
(
input_ptr
[
startOffset
+
i
*
kOffset
]);
}
#endif
return
fragA
;
}
...
...
@@ -237,15 +190,12 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
// We need to know where they start, and where the next elements are.
auto
startCoord2D
=
std
::
make_pair
((
threadIdx
.
x
/
BLOCK_N
)
*
VW
,
// Row
threadIdx
.
x
%
BLOCK_N
);
// Col
// auto stepCoord2D = std::make_pair(1u, 0u);
// Flatten to 1D col_major offsets.
auto
col_major
=
[](
auto
const
&
coord
,
auto
ld
)
{
return
coord
.
first
+
coord
.
second
*
ld
;
};
auto
startOffset
=
col_major
(
startCoord2D
,
BLOCK_K
);
// auto kOffset = col_major(stepCoord2D, BLOCK_K);
// kOffset == 1
auto
const
*
fragPtr
=
reinterpret_cast
<
BFragT
const
*>
(
input_ptr
+
startOffset
);
return
*
fragPtr
;
}
...
...
@@ -278,29 +228,16 @@ struct store_C_col_major<CType, CFragT, 16, 16>
static
constexpr
uint32_t
VW
=
vectorSize
(
cFrag
);
// 4
static
constexpr
uint32_t
Dim
=
16
;
#if 1
for
(
int
i
=
0
;
i
<
vectorSize
(
cFrag
);
++
i
)
{
printf
(
"threadIdx.x = %d; cFrag[%d] = %f
\n
"
,
static_cast
<
int
>
(
threadIdx
.
x
),
i
,
static_cast
<
float
>
(
cFrag
[
i
]));
}
#endif
// Each thread will load 4 elements.
// We need to know where they start, and where the next elements are.
auto
startCoord2D
=
std
::
make_pair
((
threadIdx
.
x
/
Dim
)
*
VW
,
// Row
threadIdx
.
x
%
Dim
);
// Col
// auto stepCoord2D = std::make_pair(1u, 0u);
// Flatten to 1D col_major offsets.
auto
col_major
=
[](
auto
const
&
coord
,
auto
ld
)
{
return
coord
.
first
+
coord
.
second
*
ld
;
};
auto
startOffset
=
col_major
(
startCoord2D
,
16
);
// auto kOffset = col_major(stepCoord2D, 16); // 1
// kOffset == 1
auto
*
fragPtr
=
reinterpret_cast
<
CFragT
*>
(
output
+
startOffset
);
*
fragPtr
=
cFrag
;
}
...
...
@@ -343,34 +280,19 @@ struct store_C_col_major<CType, CFragT, 32, 32>
static
constexpr
uint32_t
Dim
=
32
;
static
constexpr
uint32_t
M_PER_VW_CHUNK
=
VW
*
WAVE_SIZE
/
32
;
// 8
#if 1
for
(
int
i
=
0
;
i
<
vectorSize
(
cFrag
);
++
i
)
{
printf
(
"threadIdx.x = %d; cFrag[%d] = %f
\n
"
,
static_cast
<
int
>
(
threadIdx
.
x
),
i
,
static_cast
<
float
>
(
cFrag
[
i
]));
}
#endif
auto
startCoord2D
=
std
::
make_pair
((
threadIdx
.
x
/
Dim
)
*
VW
,
// Row
threadIdx
.
x
%
Dim
);
// Col
// Minor step for each 'chunk'
// auto minorStepCoord2D = std::make_pair(1u, 0u);
// Major step between 'chunks'
auto
majorStepCoord2D
=
std
::
make_pair
(
M_PER_VW_CHUNK
,
0
);
// Flatten to 1D col_major offsets.
auto
col_major
=
[](
auto
const
&
coord
,
auto
ld
)
{
return
coord
.
first
+
coord
.
second
*
ld
;
};
auto
startOffset
=
col_major
(
startCoord2D
,
32
);
// auto kMinorOffset = col_major(minorStepCoord2D, 32); // 1
auto
startOffset
=
col_major
(
startCoord2D
,
32
);
auto
kMajorOffset
=
col_major
(
majorStepCoord2D
,
32
);
// 8
// kMinorOffset == 1.
// This means we can vector store 4 contiguous elements at a time.
// we can vector store 4 contiguous elements at a time.
using
CRawT
=
typename
scalar_type
<
CFragT
>::
type
;
using
CScalarFragT
=
vector_type
<
CRawT
,
VW
>::
type
;
union
...
...
@@ -444,16 +366,9 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
*/
struct
GemmParams
{
/**
* @brief This constructor initializes the parameters for GEMM storage with default values.
*
* A[16x128] * B[128x16] = C[16x16], all row major.
*/
GemmParams
()
:
M
(
16
),
N
(
16
),
K
(
128
)
{}
ck
::
index_t
M
;
ck
::
index_t
N
;
ck
::
index_t
K
;
ck
::
index_t
M
=
16
;
ck
::
index_t
N
=
16
;
ck
::
index_t
K
=
128
;
ck
::
index_t
StrideA
=
-
1
;
ck
::
index_t
StrideB
=
-
1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment