Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
938256dd
Commit
938256dd
authored
Feb 07, 2025
by
Andriy Roshchenko
Browse files
Refactor `load_A_row_major` to follow scale mapping
parent
b138d4fd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
223 additions
and
55 deletions
+223
-55
test/mx_mfma_op/mx_mfma_op.hpp
test/mx_mfma_op/mx_mfma_op.hpp
+223
-55
No files found.
test/mx_mfma_op/mx_mfma_op.hpp
View file @
938256dd
...
...
@@ -166,67 +166,197 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
// - Data is in row major format
// This means:
// - From A we will load BLOCK_M rows of size K to satisfy our input data
// template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
// struct load_A_row_major
// {
// __device__ AFragT operator()(AType const* input_ptr)
// {
// // clang-format off
// // Register Mapping for 16x128: ||
// Register Mapping for 32x64:
// // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | ||
// Size | BLOCK_M | BLOCK_M |
// // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M
// | 0 ... 31 | 0 ... 31 |
// // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector ||
// Thread Id | 0 ... 31 | 32 ... 63 | Vector
// // Register Element ------------ ------------- ------------ ------------- Element ||
// Register Element ------------ ------------- Element
// // Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg
// 0 [0:7] | K0 | K32 | v[0]
// // Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg
// 0 [8:15] | K1 | K33 | v[1]
// // Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg
// 0 [16:23] | K2 | K34 | v[2]
// // Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg
// 0 [24:31] | K3 | K35 | v[3]
// // Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg
// 1 [0:7] | K4 | K36 | v[4]
// // Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg
// 1 [8:15] | K5 | K37 | v[5]
// // Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg
// 1 [16:23] | K6 | K38 | v[6]
// // Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg
// 1 [24:31] | K7 | K39 | v[7]
// // Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg
// 2 [0:7] | K8 | K40 | v[8]
// // Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg
// 2 [8:15] | K9 | K41 | v[9]
// // Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg
// 2 [16:23] | K10 | K42 | v[10]
// // Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg
// 2 [24:31] | K11 | K43 | v[11]
// // Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg
// 3 [0:7] | K12 | K44 | v[12]
// // Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg
// 3 [8:15] | K13 | K45 | v[13]
// // Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg
// 3 [16:23] | K14 | K46 | v[14]
// // Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg
// 3 [24:31] | K15 | K47 | v[15]
// // Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg
// 4 [0:7] | K16 | K48 | v[16]
// // Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg
// 4 [8:15] | K17 | K49 | v[17]
// // Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg
// 4 [16:23] | K18 | K50 | v[18]
// // Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg
// 4 [24:31] | K19 | K51 | v[19]
// // Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg
// 5 [0:7] | K20 | K52 | v[20]
// // Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg
// 5 [8:15] | K21 | K53 | v[21]
// // Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg
// 5 [16:23] | K22 | K54 | v[22]
// // Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg
// 5 [24:31] | K23 | K55 | v[23]
// // Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg
// 6 [0:7] | K24 | K56 | v[24]
// // Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg
// 6 [8:15] | K25 | K57 | v[25]
// // Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg
// 6 [16:23] | K26 | K58 | v[26]
// // Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg
// 6 [24:31] | K27 | K59 | v[27]
// // Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg
// 7 [0:7] | K28 | K60 | v[28]
// // Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg
// 7 [8:15] | K29 | K61 | v[29]
// // Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg
// 7 [16:23] | K30 | K62 | v[30]
// // Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg
// 7 [24:31] | K31 | K63 | v[31]
// // clang-format on
// // Here we want to load a BLOCK_M x BLOCK_K block of data.
// static constexpr uint32_t VW = vectorSize(AFragT{});
// // To start the loading process, let's visualize in 2D coords.
// // Each thread will load 32 elements.
// // We need to know where they start, and where the next elements are.
// auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
// (threadIdx.x / BLOCK_M) * VW); // Col
// // Flatten to 1D row_major offsets.
// auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second;
// };
// // BLOCK_K is a stride in A matrix
// auto startOffset = row_major(startCoord2D, BLOCK_K);
// auto const* fragPtr = reinterpret_cast<AFragT const*>(input_ptr + startOffset);
// return *fragPtr;
// }
// }
template
<
typename
AType
,
typename
AFragT
,
int32_t
BLOCK_M
,
int32_t
BLOCK_K
>
__device__
AFragT
load_A_row_major
(
AType
const
*
input_ptr
)
struct
load_A_row_major
{
// clang-format off
__device__
AFragT
operator
()(
AType
const
*
input_ptr
)
{
// clang-format off
// Register Mapping for 16x128: || Register Mapping for 32x64:
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M |
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element
// Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0]
// Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1]
// Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2]
// Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3]
// Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4]
// Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5]
// Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6]
// Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7]
// Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8]
// Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9]
// Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10]
// Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11]
// Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12]
// Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13]
// Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14]
// Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15]
// Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16]
// Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17]
// Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18]
// Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19]
// Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20]
// Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21]
// Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22]
// Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23]
// Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24]
// Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25]
// Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26]
// Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27]
// Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28]
// Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29]
// Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30]
// Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31]
// clang-format on
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | |
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
// Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] |
// Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] |
// Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] |
// Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] |
// Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] |
// Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] |
// Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] |
// Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] |
// Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] |
// Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] |
// Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] |
// Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] |
// Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] |
// Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] |
// Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] |
// Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] |
// Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] |
// Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] |
// Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] |
// Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] |
// Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] |
// Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] |
// Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] |
// Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] |
// Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] |
// Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] |
// Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] |
// Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] |
// Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] |
// Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] |
// Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] |
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
// clang-format on
static
constexpr
int32_t
WAVE_SIZE
=
64
;
// Here we want to load from rows of A in chunks of 16 elements each.
static
constexpr
uint32_t
chunk_size
=
16
;
// each chunk is separated by offset
static
constexpr
uint32_t
chunk_offset
=
chunk_size
*
WAVE_SIZE
/
BLOCK_M
;
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto
startCoord2D
=
std
::
make_pair
(
threadIdx
.
x
%
BLOCK_M
,
// Row {0-31} | {0-15}
(
threadIdx
.
x
/
BLOCK_M
)
*
chunk_size
);
// Col {0, 16} | {0, 16, 32, 48}
// Here we want to load a BLOCK_M x BLOCK_K block of data.
static
constexpr
uint32_t
VW
=
vectorSize
(
AFragT
{});
// auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows
auto
majorStepCoord2D
=
std
::
make_pair
(
0
,
chunk_offset
);
// read a chunk from a row
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto
startCoord2D
=
std
::
make_pair
(
threadIdx
.
x
%
BLOCK_M
,
// Row
(
threadIdx
.
x
/
BLOCK_M
)
*
VW
);
// Col
// Flatten to 1D row_major offsets.
auto
row_major
=
[](
auto
const
&
coord
,
auto
ld
)
{
return
coord
.
first
*
ld
+
coord
.
second
;
};
// Flatten to 1D row_major offsets.
auto
row_major
=
[](
auto
const
&
coord
,
auto
ld
)
{
return
coord
.
first
*
ld
+
coord
.
second
;
};
// BLOCK_K is a stride in A matrix
auto
startOffset
=
row_major
(
startCoord2D
,
BLOCK_K
);
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K);
auto
kMajorOffset
=
row_major
(
majorStepCoord2D
,
BLOCK_K
);
// BLOCK_K is a stride
in A
matrix
auto
startOffset
=
row_major
(
startCoord2D
,
BLOCK_K
)
;
us
in
g
A
RawT
=
typename
scalar_type
<
AFragT
>::
type
;
using
AScalarFragT
=
vector_type
<
ARawT
,
chunk_size
>::
type
;
auto
const
*
fragPtr
=
reinterpret_cast
<
AFragT
const
*>
(
input_ptr
+
startOffset
);
return
*
fragPtr
;
}
union
{
AFragT
frag
;
AScalarFragT
chunks
[
2
];
}
fragA
{};
auto
*
fragPtr
=
reinterpret_cast
<
AScalarFragT
const
*>
(
input_ptr
+
startOffset
);
fragA
.
chunks
[
0
]
=
*
fragPtr
;
fragPtr
=
reinterpret_cast
<
AScalarFragT
const
*>
(
input_ptr
+
startOffset
+
kMajorOffset
);
fragA
.
chunks
[
1
]
=
*
fragPtr
;
return
fragA
.
frag
;
}
};
// Define a load function for scaled A blocks:
// Size: (BLOCK_M x BLOCK_K)
...
...
@@ -246,8 +376,46 @@ template <typename AType,
__device__
AFragT
load_mx_A_row_major
(
AType
const
*
input_ptr
,
ScaleType
const
*
scale_ptr
,
ScaleFragT
&
fragX
)
{
// clang-format off
// Register Mapping for 16x128: || Register Mapping for 32x64:
// Size | BLOCK_M | BLOCK_M | | BLOCK_M | BLOCK_M | | || Size | BLOCK_M | BLOCK_M | | |
// M | 0 ... 15 | 0 ... 15 | | 0 ... 15 | 0 ... 15 | | Vector || M | 0 ... 31 | 0 ... 31 | Vector | |
// Thread Id | 0 ... 15 | 16 ... 31 | Scale | 32 ... 47 | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| Scale |
// Register Element ------------ ------------- ----------|------------ ------------- ----------|-----------|| Register Element |------------|-------------|--------|----------|
// Reg 0 [0:7] | K0 | K16 | x(M,0) | K32 | K48 | x(M,1) | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | x(M,0) |
// Reg 0 [8:15] | K1 | K17 | x(M,0) | K33 | K49 | x(M,1) | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | x(M,0) |
// Reg 0 [16:23] | K2 | K18 | x(M,0) | K34 | K50 | x(M,1) | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | x(M,0) |
// Reg 0 [24:31] | K3 | K19 | x(M,0) | K35 | K51 | x(M,1) | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | x(M,0) |
// Reg 1 [0:7] | K4 | K20 | x(M,0) | K36 | K52 | x(M,1) | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | x(M,0) |
// Reg 1 [8:15] | K5 | K21 | x(M,0) | K37 | K53 | x(M,1) | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | x(M,0) |
// Reg 1 [16:23] | K6 | K22 | x(M,0) | K38 | K54 | x(M,1) | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | x(M,0) |
// Reg 1 [24:31] | K7 | K23 | x(M,0) | K39 | K55 | x(M,1) | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | x(M,0) |
// Reg 2 [0:7] | K8 | K24 | x(M,0) | K40 | K56 | x(M,1) | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | x(M,0) |
// Reg 2 [8:15] | K9 | K25 | x(M,0) | K41 | K57 | x(M,1) | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | x(M,0) |
// Reg 2 [16:23] | K10 | K26 | x(M,0) | K42 | K58 | x(M,1) | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | x(M,0) |
// Reg 2 [24:31] | K11 | K27 | x(M,0) | K43 | K59 | x(M,1) | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | x(M,0) |
// Reg 3 [0:7] | K12 | K28 | x(M,0) | K44 | K60 | x(M,1) | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | x(M,0) |
// Reg 3 [8:15] | K13 | K29 | x(M,0) | K45 | K61 | x(M,1) | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | x(M,0) |
// Reg 3 [16:23] | K14 | K30 | x(M,0) | K46 | K62 | x(M,1) | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | x(M,0) |
// Reg 3 [24:31] | K15 | K31 | x(M,0) | K47 | K63 | x(M,1) | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | x(M,0) |
// Reg 4 [0:7] | K64 | K80 | x(M,2) | K96 | K112 | x(M,3) | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | x(M,1) |
// Reg 4 [8:15] | K65 | K81 | x(M,2) | K97 | K113 | x(M,3) | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | x(M,1) |
// Reg 4 [16:23] | K66 | K82 | x(M,2) | K98 | K114 | x(M,3) | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | x(M,1) |
// Reg 4 [24:31] | K67 | K83 | x(M,2) | K99 | K115 | x(M,3) | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | x(M,1) |
// Reg 5 [0:7] | K68 | K84 | x(M,2) | K100 | K116 | x(M,3) | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | x(M,1) |
// Reg 5 [8:15] | K69 | K85 | x(M,2) | K101 | K117 | x(M,3) | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | x(M,1) |
// Reg 5 [16:23] | K70 | K86 | x(M,2) | K102 | K118 | x(M,3) | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | x(M,1) |
// Reg 5 [24:31] | K71 | K87 | x(M,2) | K103 | K119 | x(M,3) | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | x(M,1) |
// Reg 6 [0:7] | K72 | K88 | x(M,2) | K104 | K120 | x(M,3) | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | x(M,1) |
// Reg 6 [8:15] | K73 | K89 | x(M,2) | K105 | K121 | x(M,3) | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | x(M,1) |
// Reg 6 [16:23] | K74 | K90 | x(M,2) | K106 | K122 | x(M,3) | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | x(M,1) |
// Reg 6 [24:31] | K75 | K91 | x(M,2) | K107 | K123 | x(M,3) | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | x(M,1) |
// Reg 7 [0:7] | K76 | K92 | x(M,2) | K108 | K124 | x(M,3) | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | x(M,1) |
// Reg 7 [8:15] | K77 | K93 | x(M,2) | K109 | K125 | x(M,3) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(M,1) |
// Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) |
// Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) |
// clang-format on
static
constexpr
uint32_t
VW
=
vectorSize
(
AFragT
{});
static_assert
(
VW
==
BLOCK_X
,
"Fragment size must be equal to BLOCK_X"
);
...
...
@@ -266,7 +434,7 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr,
// preserve upper bits obtain 8-bit exponent
fragX
=
(
fragX
&
0xFFFFFF00
)
|
(
utils
::
get_exponent_value
(
scale_ptr
[
startOffset
])
&
0xFF
);
return
load_A_row_major
<
AType
,
AFragT
,
BLOCK_M
,
BLOCK_K
>
(
input_ptr
);
return
load_A_row_major
<
AType
,
AFragT
,
BLOCK_M
,
BLOCK_K
>
{}
(
input_ptr
);
}
// Define a load function for input B blocks:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment