gaoqiong / composable_kernel_ROCM

Commit a24c1b01, authored Feb 03, 2025 by Andriy Roshchenko
Parent: 465ba138

WIP: Completed implementation for MX FP8 MFMA

Showing 3 changed files with 247 additions and 16 deletions:

  include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp   +15   -5
  test/mx_mfma_op/mx_mfma_op.cpp                          +3   -3
  test/mx_mfma_op/mx_mfma_op.hpp                         +229   -8
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -847,9 +847,14 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
     // clang-format on
     template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    __device__ void run(const FloatA& a,
+                        const int32_t& scale_a,
+                        const FloatB& b,
+                        const int32_t& scale_b,
+                        FloatC& reg_c) const
     {
-        intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+        intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(
+            a, scale_a, b, scale_b, reg_c);
     }
 };

@@ -871,9 +876,14 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
     // clang-format on
     template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    __device__ void run(const FloatA& a,
+                        const int32_t& scale_a,
+                        const FloatB& b,
+                        const int32_t& scale_b,
+                        FloatC& reg_c) const
     {
-        intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+        intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(
+            a, scale_a, b, scale_b, reg_c);
     }
 };
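The new scale_a / scale_b operands are plain int32_t values whose low byte carries the e8m0 biased exponent (the test helpers added in mx_mfma_op.hpp below pack it that way). A minimal caller-side sketch, with illustrative fragment types (AFrag, BFrag, CFrag) and made-up exponent values, assuming the usual e8m0 bias of 127:

// Sketch only: AFrag/BFrag/CFrag stand for the FP8 input fragments and the float
// accumulator fragment used with this mfma_type specialization; a_frag, b_frag and
// acc_frag are placeholders for values prepared elsewhere.
int32_t scale_a = 0;
int32_t scale_b = 0;
scale_a = (scale_a & 0xFFFFFF00) | (0x7F & 0xFF); // low byte = e8m0 exponent (0x7F ~ scale 1.0, assuming bias 127)
scale_b = (scale_b & 0xFFFFFF00) | (0x80 & 0xFF); // low byte = e8m0 exponent (0x80 ~ scale 2.0, assuming bias 127)

auto op = mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>{};
op.template run<32, 32, AFrag, BFrag, CFrag>(a_frag, scale_a, b_frag, scale_b, acc_frag);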
test/mx_mfma_op/mx_mfma_op.cpp

@@ -79,9 +79,9 @@ bool run_mxmfma_test(ck::index_t init)
     using BLayout = ck::tensor_layout::gemm::ColumnMajor;
     using CLayout = ck::tensor_layout::gemm::RowMajor;

     using AccType    = float;              // only MFMA_F32 instructions supported
     using CPUAccType = AccType;
     using ScaleType  = ck::e8m0_bexp_t;    // biased exponent type

     ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
     constexpr auto BLOCK_M = mfma_instr.m_per_blk;
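ck::e8m0_bexp_t holds only a biased exponent, so each scale is a pure power of two. A small decoding sketch, assuming the OCP MX e8m0 convention (8-bit exponent, bias 127) rather than quoting the ck implementation:

// Hypothetical standalone decode of an e8m0 biased exponent into a float scale.
// Assumption: scale = 2^(stored - 127), i.e. 0x7F -> 1.0, 0x80 -> 2.0, 0x7E -> 0.5.
// (0xFF is typically reserved for NaN in the MX spec and is ignored here.)
#include <cmath>
#include <cstdint>

inline float decode_e8m0(uint8_t stored_exponent)
{
    return std::ldexp(1.0f, static_cast<int>(stored_exponent) - 127);
}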
test/mx_mfma_op/mx_mfma_op.hpp

@@ -38,6 +38,17 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
         auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
         op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
     }
+
+    __device__ void operator()(AFragT const& fragA,
+                               const int32_t& scale_a,
+                               BFragT const& fragB,
+                               const int32_t& scale_b,
+                               AccumFragT& fragAcc)
+    {
+        auto op = mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>{};
+        op.template run<16, 16, AFragT, BFragT, AccumFragT>(
+            fragA, scale_a, fragB, scale_b, fragAcc);
+    }
 };

 template <typename AFragT, typename BFragT, typename AccumFragT>
@@ -48,6 +59,17 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
         auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
         op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
     }
+
+    __device__ void operator()(AFragT const& fragA,
+                               const int32_t& scale_a,
+                               BFragT const& fragB,
+                               const int32_t& scale_b,
+                               AccumFragT& fragAcc)
+    {
+        auto op = mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>{};
+        op.template run<32, 32, AFragT, BFragT, AccumFragT>(
+            fragA, scale_a, fragB, scale_b, fragAcc);
+    }
 };

 template <typename VecT>
@@ -137,11 +159,121 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
     return fragA;
 }

+// Define a load function for input A blocks:
+// Size: (BLOCK_M x BLOCK_K)
+// ASSUMPTION:
+// - We want contiguous BLOCK_M sized column neighbors in register.
+// - Data is in row major format
+// This means:
+// - From A we will load BLOCK_M rows of size K to satisfy our input data
+template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
+__device__ AFragT load_A_row_major(AType const* input_ptr)
+{
+    // clang-format off
+    // Register Mapping for 16x128:                                                || Register Mapping for 32x64:
+    // Size             | BLOCK_M   | BLOCK_M    | BLOCK_M    | BLOCK_M    |        || Size             | BLOCK_M   | BLOCK_M    |
+    // M                | 0 ... 15  | 0 ... 15   | 0 ... 15   | 0 ... 15   |        || M                | 0 ... 31  | 0 ... 31   |
+    // Thread Id        | 0 ... 15  | 16 ... 31  | 32 ... 47  | 48 ... 63  | Vector || Thread Id        | 0 ... 31  | 32 ... 63  | Vector
+    // Register Element ------------ ------------- ------------ ------------ Element || Register Element ------------ ------------- Element
+    // Reg 0 [0:7]      | K0        | K32        | K64        | K96        | v[0]   || Reg 0 [0:7]      | K0        | K32        | v[0]
+    // Reg 0 [8:15]     | K1        | K33        | K65        | K97        | v[1]   || Reg 0 [8:15]     | K1        | K33        | v[1]
+    // Reg 0 [16:23]    | K2        | K34        | K66        | K98        | v[2]   || Reg 0 [16:23]    | K2        | K34        | v[2]
+    // Reg 0 [24:31]    | K3        | K35        | K67        | K99        | v[3]   || Reg 0 [24:31]    | K3        | K35        | v[3]
+    // Reg 1 [0:7]      | K4        | K36        | K68        | K100       | v[4]   || Reg 1 [0:7]      | K4        | K36        | v[4]
+    // Reg 1 [8:15]     | K5        | K37        | K69        | K101       | v[5]   || Reg 1 [8:15]     | K5        | K37        | v[5]
+    // Reg 1 [16:23]    | K6        | K38        | K70        | K102       | v[6]   || Reg 1 [16:23]    | K6        | K38        | v[6]
+    // Reg 1 [24:31]    | K7        | K39        | K71        | K103       | v[7]   || Reg 1 [24:31]    | K7        | K39        | v[7]
+    // Reg 2 [0:7]      | K8        | K40        | K72        | K104       | v[8]   || Reg 2 [0:7]      | K8        | K40        | v[8]
+    // Reg 2 [8:15]     | K9        | K41        | K73        | K105       | v[9]   || Reg 2 [8:15]     | K9        | K41        | v[9]
+    // Reg 2 [16:23]    | K10       | K42        | K74        | K106       | v[10]  || Reg 2 [16:23]    | K10       | K42        | v[10]
+    // Reg 2 [24:31]    | K11       | K43        | K75        | K107       | v[11]  || Reg 2 [24:31]    | K11       | K43        | v[11]
+    // Reg 3 [0:7]      | K12       | K44        | K76        | K108       | v[12]  || Reg 3 [0:7]      | K12       | K44        | v[12]
+    // Reg 3 [8:15]     | K13       | K45        | K77        | K109       | v[13]  || Reg 3 [8:15]     | K13       | K45        | v[13]
+    // Reg 3 [16:23]    | K14       | K46        | K78        | K110       | v[14]  || Reg 3 [16:23]    | K14       | K46        | v[14]
+    // Reg 3 [24:31]    | K15       | K47        | K79        | K111       | v[15]  || Reg 3 [24:31]    | K15       | K47        | v[15]
+    // Reg 4 [0:7]      | K16       | K48        | K80        | K112       | v[16]  || Reg 4 [0:7]      | K16       | K48        | v[16]
+    // Reg 4 [8:15]     | K17       | K49        | K81        | K113       | v[17]  || Reg 4 [8:15]     | K17       | K49        | v[17]
+    // Reg 4 [16:23]    | K18       | K50        | K82        | K114       | v[18]  || Reg 4 [16:23]    | K18       | K50        | v[18]
+    // Reg 4 [24:31]    | K19       | K51        | K83        | K115       | v[19]  || Reg 4 [24:31]    | K19       | K51        | v[19]
+    // Reg 5 [0:7]      | K20       | K52        | K84        | K116       | v[20]  || Reg 5 [0:7]      | K20       | K52        | v[20]
+    // Reg 5 [8:15]     | K21       | K53        | K85        | K117       | v[21]  || Reg 5 [8:15]     | K21       | K53        | v[21]
+    // Reg 5 [16:23]    | K22       | K54        | K86        | K118       | v[22]  || Reg 5 [16:23]    | K22       | K54        | v[22]
+    // Reg 5 [24:31]    | K23       | K55        | K87        | K119       | v[23]  || Reg 5 [24:31]    | K23       | K55        | v[23]
+    // Reg 6 [0:7]      | K24       | K56        | K88        | K120       | v[24]  || Reg 6 [0:7]      | K24       | K56        | v[24]
+    // Reg 6 [8:15]     | K25       | K57        | K89        | K121       | v[25]  || Reg 6 [8:15]     | K25       | K57        | v[25]
+    // Reg 6 [16:23]    | K26       | K58        | K90        | K122       | v[26]  || Reg 6 [16:23]    | K26       | K58        | v[26]
+    // Reg 6 [24:31]    | K27       | K59        | K91        | K123       | v[27]  || Reg 6 [24:31]    | K27       | K59        | v[27]
+    // Reg 7 [0:7]      | K28       | K60        | K92        | K124       | v[28]  || Reg 7 [0:7]      | K28       | K60        | v[28]
+    // Reg 7 [8:15]     | K29       | K61        | K93        | K125       | v[29]  || Reg 7 [8:15]     | K29       | K61        | v[29]
+    // Reg 7 [16:23]    | K30       | K62        | K94        | K126       | v[30]  || Reg 7 [16:23]    | K30       | K62        | v[30]
+    // Reg 7 [24:31]    | K31       | K63        | K95        | K127       | v[31]  || Reg 7 [24:31]    | K31       | K63        | v[31]
+    // clang-format on
+
+    // Here we want to load a BLOCK_M x BLOCK_K block of data.
+    static constexpr uint32_t VW = vectorSize(AFragT{});
+
+    // To start the loading process, let's visualize in 2D coords.
+    // Each thread will load 32 elements.
+    // We need to know where they start, and where the next elements are.
+    auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M,         // Row
+                                       (threadIdx.x / BLOCK_M) * VW); // Col
+
+    // Flatten to 1D row_major offsets.
+    auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
+
+    // BLOCK_K is the stride in the A matrix
+    auto startOffset = row_major(startCoord2D, BLOCK_K);
+
+    auto const* fragPtr = reinterpret_cast<AFragT const*>(input_ptr + startOffset);
+
+    return *fragPtr;
+}
+
+// Define a load function for scaled A blocks:
+// Size: (BLOCK_M x BLOCK_K)
+// ASSUMPTION:
+// - We want contiguous BLOCK_M sized column neighbors in register.
+// - Data is in row major format
+// - The scale inputs are distributed across 64 lanes.
+// This means:
+// - From A we will load BLOCK_M rows of size K to satisfy our input data
+template <typename AType,
+          typename AFragT,
+          typename ScaleType,
+          typename ScaleFragT,
+          int32_t BLOCK_M,
+          int32_t BLOCK_K,
+          int32_t BLOCK_X>
+__device__ AFragT
+load_mx_A_row_major(AType const* input_ptr, ScaleType const* scale_ptr, ScaleFragT& fragX)
+{
+    static constexpr uint32_t VW = vectorSize(AFragT{});
+    static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
+
+    // To start the loading process, let's visualize in 2D coords.
+    // Each thread will load 1 element.
+    // We need to know where it starts.
+    auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M,                   // Row
+                                       (threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col
+
+    // Flatten to 1D row_major offsets.
+    auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
+
+    // BLOCK_K / BLOCK_X is the stride in the xA (scale) matrix
+    auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X);
+
+    // Preserve the upper bits, place the 8-bit biased exponent in the low byte.
+    fragX = (fragX & 0xFFFFFF00) | (utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF);
+
+    return load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(input_ptr);
+}
+
 // Define a load function for input B blocks:
 // Size: (BLOCK_K x BLOCK_N)
 // ASSUMPTION:
 // - We want contiguous BLOCK_N sized row neighbors in register.
-// - Data is in row_major format
+// - Data is in column major format
 // This means:
 // - From B we will load K rows of size BLOCK_N to satisfy our input data
 template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
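A quick sanity check of the lane-to-offset mapping above, as a standalone sketch for the 16x128 case (BLOCK_M = 16, BLOCK_K = 128, VW = 32):

// Host-side sanity check of the lane -> offset mapping used by load_A_row_major.
// Assumes the 16x128 configuration shown in the register map above.
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr int32_t BLOCK_M = 16, BLOCK_K = 128, VW = 32;
    for(int lane = 0; lane < 64; lane += 16)
    {
        int32_t row    = lane % BLOCK_M;
        int32_t col    = (lane / BLOCK_M) * VW;
        int32_t offset = row * BLOCK_K + col; // row-major flatten, as in the kernel
        std::printf("lane %2d -> row %d, col %3d, offset %4d\n", lane, row, col, offset);
    }
    return 0; // lanes 0, 16, 32, 48 start row 0 at K0, K32, K64, K96, matching the table
}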
@@ -205,6 +337,46 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
     return *fragPtr;
 }

+// Define a load function for scaled B blocks:
+// Size: (BLOCK_K x BLOCK_N)
+// ASSUMPTION:
+// - We want contiguous BLOCK_N sized row neighbors in register.
+// - Data is in column major format
+// - The scale inputs are distributed across 64 lanes.
+// This means:
+// - From B we will load K rows of size BLOCK_N to satisfy our input data
+template <typename BType,
+          typename BFragT,
+          typename ScaleType,
+          typename ScaleFragT,
+          int32_t BLOCK_K,
+          int32_t BLOCK_N,
+          int32_t BLOCK_X>
+__device__ BFragT
+load_mx_B_col_major(BType const* input_ptr, ScaleType const* scale_ptr, ScaleFragT& fragX)
+{
+    static constexpr uint32_t VW = vectorSize(BFragT{});
+    static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
+
+    // To start the loading process, let's visualize in 2D coords.
+    // Each thread will load 1 element.
+    // We need to know where to start.
+    auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row
+                                       threadIdx.x % BLOCK_N);                 // Col
+
+    // Flatten to 1D col_major offsets.
+    auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
+
+    // BLOCK_K / BLOCK_X is the stride in the xB (scale) matrix
+    auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X);
+
+    // Preserve the upper bits, place the 8-bit biased exponent in the low byte.
+    fragX = (fragX & 0xFFFFFF00) | (utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF);
+
+    return load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(input_ptr);
+}
+
 // Define a store function for C
 // Size: (BLOCK_M x BLOCK_N)
 // ASSUMPTION:
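The matching check for the scale offsets used by load_mx_B_col_major, again as a standalone sketch assuming the 16x16x128 configuration (BLOCK_K = 128, BLOCK_N = 16, BLOCK_X = VW = 32), where the xB block is (BLOCK_K / BLOCK_X) x BLOCK_N = 4 x 16 in column-major order:

// Host-side sanity check of the lane -> scale-offset mapping in load_mx_B_col_major.
// Configuration values are assumed as stated above; every lane reads exactly one e8m0 value.
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr int32_t BLOCK_K = 128, BLOCK_N = 16, BLOCK_X = 32, VW = 32;
    const int lanes[] = {0, 1, 17, 63};
    for(int lane : lanes)
    {
        int32_t row    = (lane / BLOCK_N) * VW / BLOCK_X; // which K slice
        int32_t col    = lane % BLOCK_N;                  // which column of B
        int32_t offset = row + col * (BLOCK_K / BLOCK_X); // column-major flatten
        std::printf("lane %2d -> xB row %d, col %2d, offset %3d\n", lane, row, col, offset);
    }
    return 0; // all offsets fall inside the 4 x 16 = 64 scale entries
}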
@@ -368,12 +540,49 @@ template <typename AType,
           int32_t BLOCK_N,
           int32_t BLOCK_K,
           int32_t BLOCK_X>
-__global__ void matmul(const AType* a, const BType* b, const ScaleType* x, CType* c)
+__global__ void matmul(
+    const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c)
 {
-    ignore = a;
-    ignore = b;
-    ignore = x;
-    ignore = c;
+    constexpr int WAVE_SIZE = 64;
+    assert(threadIdx.x < WAVE_SIZE);
+    assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
+
+    using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
+    using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
+    using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
+
+    using AccumFragT    = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
+    using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
+    using ScaleFragT    = int32_t;
+
+    // Create frags
+    auto fragA   = AFragT{};
+    auto fragB   = BFragT{};
+    auto fragC   = CFragT{};
+    auto fragAcc = AccumFragT{0};
+    auto fragXa  = ScaleFragT{0};
+    auto fragXb  = ScaleFragT{0};
+
+    // Load the inputs.
+    // A = row major, BLOCK_M x BLOCK_K
+    fragA = load_mx_A_row_major<AType, AFragT, ScaleType, ScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
+        a, xa, fragXa);
+    // B = col major, BLOCK_K x BLOCK_N
+    fragB = load_mx_B_col_major<BType, BFragT, ScaleType, ScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
+        b, xb, fragXb);
+
+    // Scaled matrix multiply-accumulate using MFMA units
+    // Accumulation intermediate = BLOCK_M x BLOCK_N
+    mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(
+        fragA, fragXa, fragB, fragXb, fragAcc);
+
+    for(int i = 0; i < vectorSize(fragC); ++i)
+    {
+        fragC[i] =
+            type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
+    }
+
+    auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
+    storeC(c, fragC);
 }

 /**
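For reference, the arithmetic the scaled kernel is meant to reproduce, written as a plain host loop. This is a sketch under stated assumptions, not the test's RunHostGEMM: e8m0 scales decode as 2^(stored - 127), and there is one scale per BLOCK_X-wide K slice per row of A / column of B, mirroring the loaders above.

// Minimal host reference for a single scaled block. Assumptions:
//  - e8m0 scales decode as 2^(stored - 127) (OCP MX convention),
//  - A is row major, B is column major, C is row major,
//  - the scale layouts match load_mx_A_row_major / load_mx_B_col_major.
// Types and array shapes are illustrative, not the test's Tensor<> API.
#include <cmath>
#include <cstdint>
#include <vector>

void reference_mx_gemm(const std::vector<float>& A,    // BLOCK_M x BLOCK_K, row major (already converted from FP8)
                       const std::vector<uint8_t>& xa, // BLOCK_M x (BLOCK_K / BLOCK_X), row major
                       const std::vector<float>& B,    // BLOCK_K x BLOCK_N, column major
                       const std::vector<uint8_t>& xb, // (BLOCK_K / BLOCK_X) x BLOCK_N, column major
                       std::vector<float>& C,          // BLOCK_M x BLOCK_N, row major
                       int BLOCK_M, int BLOCK_N, int BLOCK_K, int BLOCK_X)
{
    for(int m = 0; m < BLOCK_M; ++m)
        for(int n = 0; n < BLOCK_N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < BLOCK_K; ++k)
            {
                // Decode the e8m0 scale for this (row, K-slice) / (K-slice, column) pair.
                float sa = std::ldexp(1.0f, xa[m * (BLOCK_K / BLOCK_X) + k / BLOCK_X] - 127);
                float sb = std::ldexp(1.0f, xb[(k / BLOCK_X) + n * (BLOCK_K / BLOCK_X)] - 127);
                acc += (sa * A[m * BLOCK_K + k]) * (sb * B[k + n * BLOCK_K]);
            }
            C[m * BLOCK_N + n] = acc;
        }
}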
@@ -443,20 +652,32 @@ void RunHostGEMM(const Tensor<ADataType>& A,
     ref_invoker.Run(ref_argument);
 }

-template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
+template <typename KernelType,
+          typename ADataType,
+          typename BDataType,
+          typename ScaleType,
+          typename CDataType>
 bool RunDeviceGEMM(KernelType kernel,
                    const Tensor<ADataType>& A,
+                   const Tensor<ScaleType>& a_scales,
                    const Tensor<BDataType>& B,
+                   const Tensor<ScaleType>& b_scales,
                    Tensor<CDataType>& C)
 {
     DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
+    DeviceMem a_scales_device_buf(sizeof(ScaleType) * a_scales.mDesc.GetElementSpaceSize());
     DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
+    DeviceMem b_scales_device_buf(sizeof(ScaleType) * b_scales.mDesc.GetElementSpaceSize());
     DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());

     a_m_k_device_buf.ToDevice(A.mData.data());
+    a_scales_device_buf.ToDevice(a_scales.mData.data());
     b_n_k_device_buf.ToDevice(B.mData.data());
+    b_scales_device_buf.ToDevice(b_scales.mData.data());

     kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+                      static_cast<const ScaleType*>(a_scales_device_buf.GetDeviceBuffer()),
                       static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
+                      static_cast<const ScaleType*>(b_scales_device_buf.GetDeviceBuffer()),
                       static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));

     c_m_n_device_buf.FromDevice(C.mData.data());
@@ -600,7 +821,7 @@ struct TestMXMFMA
         RunHostGEMM(a, a_scales, b, b_scales, c_host);

-        RunDeviceGEMM(mfma_kernel, a, b, c_device);
+        RunDeviceGEMM(mfma_kernel, a, a_scales, b, b_scales, c_device);

         bool res = false;
         if constexpr(std::is_same<CDataType, float>::value ||