Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9f8e26f6
"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "bdbc7c664b4adf19c013d296c58dfafeb6e8fdf7"
Commit
9f8e26f6
authored
Feb 04, 2025
by
Andriy Roshchenko
Browse files
Add row-major C store
parent
6c39e6af
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
124 additions
and
1 deletion
+124
-1
test/mx_mfma_op/mx_mfma_op.hpp
test/mx_mfma_op/mx_mfma_op.hpp
+124
-1
No files found.
test/mx_mfma_op/mx_mfma_op.hpp
View file @
9f8e26f6
...
@@ -487,6 +487,129 @@ struct store_C_col_major<CType, CFragT, 32, 32>
...
@@ -487,6 +487,129 @@ struct store_C_col_major<CType, CFragT, 32, 32>
}
}
};
};
// Define a store function for C
// Size: (BLOCK_M x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_N sized row neighbors in register.
// - Data is in row major format
template
<
typename
CType
,
typename
CFragT
,
int32_t
BLOCK_M
,
int32_t
BLOCK_N
>
struct
store_C_row_major
;
// Here we want to store a 16x16 block of data.
//
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector
// Register Element ------------ ------------- ------------ -------------- Element
// Reg0 | M0 | M4 | M8 | M12 | v[0]
// Reg1 | M1 | M5 | M9 | M13 | v[1]
// Reg2 | M2 | M6 | M10 | M14 | v[2]
// Reg3 | M3 | M7 | M11 | M15 | v[3]
template
<
typename
CType
,
typename
CFragT
>
struct
store_C_row_major
<
CType
,
CFragT
,
16
,
16
>
{
__device__
void
operator
()(
CType
*
output
,
CFragT
cFrag
)
{
static
constexpr
uint32_t
VW
=
vectorSize
(
cFrag
);
// 4
static
constexpr
uint32_t
Dim
=
16
;
// Each thread will load 4 elements.
// We need to know where they start, and where the next elements are.
auto
startCoord2D
=
std
::
make_pair
((
threadIdx
.
x
/
Dim
)
*
VW
,
// Row
threadIdx
.
x
%
Dim
);
// Col
auto
stepCoord2D
=
std
::
make_pair
(
1u
,
0u
);
// Flatten to 1D row_major offsets.
auto
row_major
=
[](
auto
const
&
coord
,
auto
ld
)
{
return
coord
.
first
*
ld
+
coord
.
second
;
};
auto
startOffset
=
row_major
(
startCoord2D
,
16
);
auto
kOffset
=
row_major
(
stepCoord2D
,
16
);
auto
*
fragPtr
=
reinterpret_cast
<
CFragT
*>
(
output
+
startOffset
);
*
fragPtr
=
cFrag
;
// If you notice carefully, kOffset != 1.
// This means the following is vector is updated with 4 non-contiguous offsets,
// which the compiler will separate into 4 different global_store_dword instructions.
output
[
startOffset
]
=
cFrag
[
0
];
// v[0] = Reg 0
output
[
startOffset
+
kOffset
]
=
cFrag
[
1
];
// v[1] = Reg 1
output
[
startOffset
+
2
*
kOffset
]
=
cFrag
[
2
];
// v[2] = Reg 2
output
[
startOffset
+
3
*
kOffset
]
=
cFrag
[
3
];
// v[3] = Reg 3
}
};
// Here we want to store a 32x32 block of data.
// Register Mapping:
// Size | BLOCK_N | BLOCK_N |
// N | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- Element
// Reg0 | M0 | M4 | v[0]
// Reg1 | M1 | M5 | v[1]
// Reg2 | M2 | M6 | v[2]
// Reg3 | M3 | M7 | v[3]
// ____________ _____________
// Reg4 | M8 | M12 | v[4]
// Reg5 | M9 | M13 | v[5]
// Reg6 | M10 | M14 | v[6]
// Reg7 | M11 | M15 | v[7]
// ____________ _____________
// Reg8 | M16 | M20 | v[8]
// Reg9 | M17 | M21 | v[9]
// Reg10 | M18 | M22 | v[10]
// Reg11 | M19 | M23 | v[11]
// ____________ _____________
// Reg12 | M24 | M28 | v[12]
// Reg13 | M25 | M29 | v[13]
// Reg14 | M26 | M30 | v[14]
// Reg15 | M27 | M31 | v[15]
template
<
typename
CType
,
typename
CFragT
>
struct
store_C_row_major
<
CType
,
CFragT
,
32
,
32
>
{
__device__
void
operator
()(
CType
*
output
,
CFragT
cFrag
)
{
static
constexpr
uint32_t
WAVE_SIZE
=
64
;
static
constexpr
uint32_t
VW
=
4
;
// This VW is per 'chunk'
static
constexpr
uint32_t
Dim
=
32
;
// BLOCK_N
static
constexpr
uint32_t
M_PER_VW_CHUNK
=
VW
*
WAVE_SIZE
/
32
;
// 8
auto
startCoord2D
=
std
::
make_pair
((
threadIdx
.
x
/
Dim
)
*
VW
,
// Row
threadIdx
.
x
%
Dim
);
// Col
// Minor step for each 'chunk'
auto
minorStepCoord2D
=
std
::
make_pair
(
1u
,
0u
);
// Major step between 'chunks'
auto
majorStepCoord2D
=
std
::
make_pair
(
M_PER_VW_CHUNK
,
0
);
// Flatten to 1D row_major offsets.
auto
row_major
=
[](
auto
const
&
coord
,
auto
ld
)
{
return
coord
.
first
*
ld
+
coord
.
second
;
};
auto
startOffset
=
row_major
(
startCoord2D
,
32
);
auto
kMinorOffset
=
row_major
(
minorStepCoord2D
,
32
);
auto
kMajorOffset
=
row_major
(
majorStepCoord2D
,
32
);
output
[
startOffset
]
=
cFrag
[
0
];
// v[0] = Reg 0
output
[
startOffset
+
kMinorOffset
]
=
cFrag
[
1
];
// v[1] = Reg 1
output
[
startOffset
+
2
*
kMinorOffset
]
=
cFrag
[
2
];
// v[2] = Reg 2
output
[
startOffset
+
3
*
kMinorOffset
]
=
cFrag
[
3
];
// v[3] = Reg 3
output
[
startOffset
+
kMajorOffset
]
=
cFrag
[
4
];
// v[4] = Reg 4
output
[
startOffset
+
kMajorOffset
+
kMinorOffset
]
=
cFrag
[
5
];
// v[5] = Reg 5
output
[
startOffset
+
kMajorOffset
+
2
*
kMinorOffset
]
=
cFrag
[
6
];
// v[6] = Reg 6
output
[
startOffset
+
kMajorOffset
+
3
*
kMinorOffset
]
=
cFrag
[
7
];
// v[7] = Reg 7
output
[
startOffset
+
2
*
kMajorOffset
]
=
cFrag
[
8
];
// v[8] = Reg 8
output
[
startOffset
+
2
*
kMajorOffset
+
kMinorOffset
]
=
cFrag
[
9
];
// v[9] = Reg 9
output
[
startOffset
+
2
*
kMajorOffset
+
2
*
kMinorOffset
]
=
cFrag
[
10
];
// v[10] = Reg 10
output
[
startOffset
+
2
*
kMajorOffset
+
3
*
kMinorOffset
]
=
cFrag
[
11
];
// v[11] = Reg 11
output
[
startOffset
+
3
*
kMajorOffset
]
=
cFrag
[
12
];
// v[12] = Reg 12
output
[
startOffset
+
3
*
kMajorOffset
+
kMinorOffset
]
=
cFrag
[
13
];
// v[13] = Reg 13
output
[
startOffset
+
3
*
kMajorOffset
+
2
*
kMinorOffset
]
=
cFrag
[
14
];
// v[14] = Reg 14
output
[
startOffset
+
3
*
kMajorOffset
+
3
*
kMinorOffset
]
=
cFrag
[
15
];
// v[15] = Reg 15
}
};
template
<
typename
AType
,
template
<
typename
AType
,
typename
BType
,
typename
BType
,
typename
CType
,
typename
CType
,
...
@@ -581,7 +704,7 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
...
@@ -581,7 +704,7 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
fragC
[
i
]
=
type_convert
<
CType
>
(
fragAcc
.
template
AsType
<
RawAccumFragT
>()[
Number
<
0
>
{}][
i
]);
fragC
[
i
]
=
type_convert
<
CType
>
(
fragAcc
.
template
AsType
<
RawAccumFragT
>()[
Number
<
0
>
{}][
i
]);
}
}
auto
storeC
=
store_C_
col
_major
<
CType
,
CFragT
,
BLOCK_M
,
BLOCK_N
>
{};
auto
storeC
=
store_C_
row
_major
<
CType
,
CFragT
,
BLOCK_M
,
BLOCK_N
>
{};
storeC
(
c
,
fragC
);
storeC
(
c
,
fragC
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment