gaoqiong / composable_kernel_ROCM · Commits

Commit dbb7002d, authored Feb 06, 2025 by Adam Osewski

    Merge remote-tracking branch 'origin/develop' into aosewski/hotloop

Parents: 96c8d948, 2bef5501

Showing 8 changed files with 1991 additions and 3 deletions:
  test/data_type/test_fp8_ocp.cpp     +2    -2
  test/data_type/test_mx_bf8.cpp      +654  -0
  test/data_type/test_mx_fp8.cpp      +616  -0
  test/data_type/test_pk_i4.cpp       +77   -0
  test/mx_mfma_op/CMakeLists.txt      +9    -0
  test/mx_mfma_op/mx_mfma_op.cpp      +65   -0
  test/mx_mfma_op/mx_mfma_op.hpp      +567  -0
  test/smfmac_op/smfmac_op_xdl.cpp    +1    -1
test/data_type/test_fp8_ocp.cpp

@@ -60,8 +60,8 @@ TEST(FP8OCP, ConvertFP32Nearest)
     float neg_float = -0.015625f; //-2^-6
     ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(neg_float)), 0.0f);

-    // positive subnorm float value to fp8 and back, check if holds
-    pos_float = 0.00390625f;
+    // positive subnorm fp8 value to fp8 and back, check if holds
+    pos_float = 0.00390625f; // 2^-8
     ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne<f8_ocp_t>(pos_float)), abs_tol);

     // min subnorm fp8 value to fp8 and back, check if holds
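For reference, the constants in this hunk line up with the OCP e4m3 layout behind f8_ocp_t: with exponent bias 7 and three mantissa bits, the smallest normal magnitude is 2^-6 and subnormals step by 2^-9, so 0.00390625 = 2^-8 is exactly representable as a subnormal. A minimal standalone check of those values, using only standard C++ (an illustrative sketch, not CK code):

#include <cmath>
#include <cstdio>

int main()
{
    // OCP e4m3: bias 7, 3 mantissa bits -> min normal 2^-6, subnormal step 2^-9.
    float min_normal  = std::ldexp(1.0f, -6); // 0.015625, the magnitude used for neg_float
    float subnorm     = std::ldexp(1.0f, -8); // 0.00390625 == 2 * 2^-9, exactly representable
    float min_subnorm = std::ldexp(1.0f, -9); // smallest positive e4m3 value
    std::printf("%g %g %g\n", min_normal, subnorm, min_subnorm);
    return 0;
}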
test/data_type/test_mx_bf8.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <cmath>
#include <cstdint>
#include <limits>
#include <set>
#include <vector>

#include "gtest/gtest.h"

#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/scaled_type_convert.hpp"

using ck::bf8_ocp_t;
using ck::bf8x16_ocp_t;
using ck::bf8x2_ocp_t;
using ck::bf8x32_ocp_t;
using ck::e8m0_bexp_t;
using ck::float16_t;
using ck::float2_t;
using ck::float32_t;
using ck::mxf8_convert_rne;
using ck::mxf8_convert_sr;
using ck::scaled_type_convert;
using ck::type_convert;

constexpr uint64_t test_size = 256 * 256 + 2 + 4 + 6;
/**
* @brief Tests conversion of BF8 values to float using E8M0 exponent scaling.
*
* This function performs a series of conversions from BF8 values to float values using
* E8M0 exponent scaling. It handles all possible combinations of E8M0 and BF8 values,
* as well as specific vector and rounding conversions.
*
* @param N The maximum number of conversions to perform.
* @param p_test Pointer to the output array where the converted float values will be stored.
* @param p_completed Pointer to a variable that tracks the number of completed conversions.
*
* @note If either p_test or p_completed is nullptr, the function will return immediately.
* @note The function will stop converting if the number of conversions reaches N.
* @note First 256*256 conversions are for all possible combinations of E8M0 and BF8 values that are
* stored in memory sequentially with BF8 values varying faster.
*
* The function performs the following conversions:
* - All possible combinations of E8M0 and BF8 values. [256x256]
* - Vector conversions bf8x2 -> f32x2. [2]
* - Vector conversions f32x2 -> bf8x2 rne. [2]
* - Vector conversions f32x2 -> bf8x2 sr. [2]
* - Round to nearest even conversions for specific float values. [6]
*
* The results are stored in the p_test array, and the number of completed conversions
* is updated in the p_completed variable.
*/
__host__ __device__ void test_mx_bf8_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& i = *p_completed;
    i           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    // All possible combinations of E8M0 and BF8
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
        {
            uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);

            auto v    = scaled_type_convert<float>(e8m0_bexp_t(exp_id), bf8_ocp_t{bf8_uid});
            p_test[i] = v;
            i++;
            if(i >= N)
            {
                return;
            }
        }
    }

    /// Test vector conversions

    // bf8x2 -> f32x2
    bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
    auto scale = e8m0_bexp_t(8.0f);

    float2_t f32x2 = scaled_type_convert<float2_t>(scale, bf8x2);

    p_test[i++] = f32x2[0];
    if(i >= N)
    {
        return;
    }
    p_test[i++] = f32x2[1];
    if(i >= N)
    {
        return;
    }

    // f32x2 -> bf8x2
    f32x2 = {-8.0f, 4.0f};

    auto scale2 = e8m0_bexp_t(2.0f);

    bf8x2 = mxf8_convert_rne<bf8x2_ocp_t>(f32x2, type_convert<float>(scale2)); // expect {-4, 2}

    p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
    if(i >= N)
    {
        return;
    }

    auto scale4 = e8m0_bexp_t(4.0f);

    bf8x2 = mxf8_convert_sr<bf8x2_ocp_t>(f32x2, type_convert<float>(scale4)); // expect {-2, 1}

    p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-2f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 1f
    if(i >= N)
    {
        return;
    }

    /// Test round to nearest even
    p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(1024.0f, 4.0f)); // 1024/4
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(
        mxf8_convert_rne<bf8_ocp_t>(std::numeric_limits<float>::quiet_NaN(), 4.0f)); // => NaN
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(
        std::numeric_limits<float>::infinity(), 2.0f)); // => BF8 Inf on device
    if(i >= N)
    {
        return;
    }
    // 31000/0.5 > 57344 => BF8 Inf on device
    p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(31000.0f, 0.5f));
    if(i >= N)
    {
        return;
    }
    // -31000/0.5 < -57344 => -BF8 Inf on device
    p_test[i++] = type_convert<float>(mxf8_convert_rne<bf8_ocp_t>(-31000.0f, 0.5f));
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(
        mxf8_convert_rne<bf8_ocp_t>(powf(2.0f, 16.0f), 4.0f)); // 2^16/4 = 65536/4
    if(i >= N)
    {
        return;
    }
}

TEST(MXBF8, HostScaledConvert)
{
    std::vector<float> out(test_size, -1.0f);
    uint64_t completed = 0;

    test_mx_bf8_scaled_convert(test_size, out.data(), &completed);

    // V = X * P; X - E8M0 scale, P - BF8
    // If X = NaN, then V = NaN regardless of P
    uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        auto idx = e8m0_nan_id * 256 + bf8_id;
        ASSERT_TRUE(std::isnan(out[idx]));
    }

    // If P in {Inf, NaN}, then V = P
    std::set<uint8_t> bf8_spec_ids;
    bf8_spec_ids.insert(0b11111111); // -NaN
    bf8_spec_ids.insert(0b01111111); // +NaN
    bf8_spec_ids.insert(0b11111101); // -NaN
    bf8_spec_ids.insert(0b01111101); // +NaN
    bf8_spec_ids.insert(0b11111110); // -NaN
    bf8_spec_ids.insert(0b01111110); // +NaN
    bf8_spec_ids.insert(0b11111100); // -inf
    bf8_spec_ids.insert(0b01111100); // +inf

    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;

        for(auto bf8_spec_id : bf8_spec_ids)
        {
            auto idx = exp_id * 256 + bf8_spec_id;
            if(std::isnan(type_convert<float>(bf8_ocp_t{bf8_spec_id})))
            {
                ASSERT_TRUE(std::isnan(out[idx]))
                    << "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
                    << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                    << type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
            }
            else
            {
                ASSERT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_spec_id}))
                    << "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
                    << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                    << type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
            }
        }
    }

    // V = X * P; X, P - finite
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;

        for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
        {
            if(bf8_spec_ids.find(bf8_id) != bf8_spec_ids.end())
                continue;

            uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);

            auto idx = exp_id * 256 + bf8_uid;
            ASSERT_FLOAT_EQ(out[idx],
                            type_convert<float>(e8m0_bexp_t(exp_id)) *
                                type_convert<float>(bf8_ocp_t{bf8_uid}))
                << "exp_id: " << exp_id << " bf8_id: " << bf8_uid << std::endl
                << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                << type_convert<float>(bf8_ocp_t{bf8_uid});
        }
    }

    /// Test vector conversions
    auto i = 256 * 256;

    // bf8x2 -> f32x2
    EXPECT_EQ(out[i++], -powf(2.0f, -11.0f));
    EXPECT_EQ(out[i++], powf(2.0f, -13.0f));

    // f32x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    // SR
    EXPECT_EQ(out[i++], -2.0f);
    EXPECT_EQ(out[i++], 1.0f);

    /// Test round to nearest even
    EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Lowest()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], powf(2.0f, 14.0f)) << "out[i-1]: " << out[i - 1];

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}

__global__ void test_mx_bf8_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    test_mx_bf8_scaled_convert(N, p_test, p_completed);
}

TEST(MXBF8, DeviceScaledConvert)
{
    std::vector<float> out(test_size, -1.0f);
    DeviceMem device_out(test_size * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_bf8_device_scaled_convert<<<1, 1>>>(
        test_size,
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    // V = X * P; X - E8M0 scale, P - BF8
    // If X = NaN, then V = NaN regardless of P
    uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        auto idx = e8m0_nan_id * 256 + bf8_id;
        ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
    }

    // If P in {Inf, NaN}, then V = P
    std::set<uint8_t> bf8_spec_ids;
    bf8_spec_ids.insert(0b11111111); //-NaN
    bf8_spec_ids.insert(0b01111111); // +NaN
    bf8_spec_ids.insert(0b11111101); //-NaN
    bf8_spec_ids.insert(0b01111101); // +NaN
    bf8_spec_ids.insert(0b11111110); //-NaN
    bf8_spec_ids.insert(0b01111110); // +NaN
    bf8_spec_ids.insert(0b11111100); //-inf
    bf8_spec_ids.insert(0b01111100); // +inf

    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;

        for(auto bf8_spec_id : bf8_spec_ids)
        {
            auto idx = exp_id * 256 + bf8_spec_id;
            if(std::isnan(type_convert<float>(bf8_ocp_t{bf8_spec_id})))
            {
                ASSERT_TRUE(std::isnan(out[idx]))
                    << "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
                    << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                    << type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
            }
            else
            {
                ASSERT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_spec_id}))
                    << "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl
                    << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                    << type_convert<float>(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx];
            }
        }
    }

    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;

        for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
        {
            if(bf8_spec_ids.find(bf8_id) != bf8_spec_ids.end())
                continue;

            uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);

            auto idx = exp_id * 256 + bf8_uid;
            ASSERT_FLOAT_EQ(out[idx],
                            type_convert<float>(e8m0_bexp_t(exp_id)) *
                                type_convert<float>(bf8_ocp_t{bf8_uid}))
                << "exp_id: " << exp_id << " bf8_id: " << bf8_uid << std::endl
                << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                << type_convert<float>(bf8_ocp_t{bf8_uid});
        }
    }

    /// Test vector conversions
    auto i = 256 * 256;

    // bf8x2 -> f32x2
    EXPECT_EQ(out[i++], -powf(2.0f, -11.0f));
    EXPECT_EQ(out[i++], powf(2.0f, -13.0f));

    // f32x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    // SR
    EXPECT_EQ(out[i++], -2.0f);
    EXPECT_EQ(out[i++], 1.0f);

    /// Test round to nearest even
    EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#if 1
    EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1];
#else
    // NOTE: Host and Device have different behavior.
    // Device returns Infs, while Host returns Max (saturation to finite value).
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<bf8_ocp_t>::Lowest()))
        << "out[i-1]: " << out[i - 1];
#endif
    EXPECT_EQ(out[i++], powf(2.0f, 14.0f)) << "out[i-1]: " << out[i - 1];

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
__host__ __device__ float vec16_generator(ck::index_t i) { return powf(-1.0f, i) * powf(2.0f, i); }

__global__ void test_mx_bf8x16_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int N = 16;

    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& i = *p_completed;
    i           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    auto scale2 = e8m0_bexp_t(2.0f);
    bf8x16_ocp_t bf8x16{};
    float16_t float16{};

    ck::static_for<0, N, 1>{}(
        [&](auto ii) { float16[static_cast<int>(ii)] = vec16_generator(ii); });

    bf8x16 = scaled_type_convert<bf8x16_ocp_t>(scale2, float16);

    ck::static_for<0, N, 1>{}([&](auto ii) {
        p_test[i++] = type_convert<float>(bf8x16.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
    });
}

TEST(MXBF8, DeviceF32x16ToBF8x16ScaledConvert)
{
    constexpr int N = 16;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_bf8x16_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec16_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}

__host__ __device__ float vec32_generator(ck::index_t i)
{
    if(i < 16)
    {
        return vec16_generator(i % 16);
    }
    else
    {
        return 1.5f * vec16_generator(i % 16);
    }
}

__global__ void test_mx_bf8x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int N = 32;

    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& i = *p_completed;
    i           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    auto scale2 = e8m0_bexp_t(2.0f);
    bf8x32_ocp_t bf8x32{};
    float32_t float32{};

    ck::static_for<0, N, 1>{}(
        [&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });

    bf8x32 = mxf8_convert_rne<bf8x32_ocp_t>(float32, type_convert<float>(scale2));

    ck::static_for<0, N, 1>{}([&](auto ii) {
        p_test[i++] = type_convert<float>(bf8x32.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
    });
}

TEST(MXBF8, DeviceF32x32ToBF8x32ScaledConvert)
{
    constexpr int N = 32;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_bf8x32_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}

__global__ void test_mx_bf8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
{
    constexpr int N = 32;

    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& i = *p_completed;
    i           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    auto scale2 = e8m0_bexp_t(8.0f);
    bf8x32_ocp_t bf8x32{};
    float32_t float32{};

    ck::static_for<0, N, 1>{}(
        [&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });

    bf8x32 = mxf8_convert_sr<bf8x32_ocp_t>(float32, type_convert<float>(scale2));

    ck::static_for<0, N, 1>{}([&](auto ii) {
        p_test[i++] = type_convert<float>(bf8x32.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
    });
}

TEST(MXBF8, DeviceF32x32ToBF8x32ScaledConvertSR)
{
    constexpr int N = 32;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_bf8x32_device_scaled_convert_sr<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}

__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int N = 32;

    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& i = *p_completed;
    i           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    auto scale2 = e8m0_bexp_t(4.0f);
    bf8x32_ocp_t bf8x32{};
    float32_t float32{};

    ck::static_for<0, N, 1>{}([&](auto ii) {
        bf8x32.AsType<bf8_ocp_t>()(ii) = type_convert<bf8_ocp_t>(vec32_generator(ii) / 16.0f);
    });

    float32 = scaled_type_convert<float32_t>(scale2, bf8x32);

    ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32[static_cast<int>(ii)]; });
}

TEST(MXBF8, DeviceBF8x32ToF32x32ScaledConvert)
{
    constexpr int N = 32;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}
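Taken together, the assertions above encode a simple model: an E8M0 byte is a pure biased exponent, and a scaled conversion is a power-of-two multiply in which a NaN scale poisons the result and Inf/NaN payloads pass through. A host-only sketch of that model, assuming the OCP MX conventions (bias 127, the all-ones byte 0xFF reserved for NaN); this mirrors what the tests check, not CK's actual implementation:

#include <cmath>
#include <cstdint>
#include <limits>

// Decode an E8M0 biased-exponent byte: value = 2^(e - 127); 0xFF encodes NaN
// (assumed per the OCP MX spec).
float e8m0_to_float(uint8_t e)
{
    if(e == 0xFF)
        return std::numeric_limits<float>::quiet_NaN();
    return std::ldexp(1.0f, static_cast<int>(e) - 127);
}

// Scaled conversion V = X * P, matching the identities the tests assert:
// X = NaN -> V = NaN for any P; P in {Inf, NaN} -> V = P; otherwise a plain product.
float scaled_convert(uint8_t scale_byte, float p) { return e8m0_to_float(scale_byte) * p; }

For example, e8m0_bexp_t(8.0f) stores the byte 130 (127 + 3), so scaled_convert(130, p) returns 8 * p, which is exactly the expectation used in the bf8x2 -> f32x2 checks above.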
test/data_type/test_mx_fp8.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <cmath>
#include <cstdint>
#include <limits>
#include <set>
#include <vector>

#include "gtest/gtest.h"

#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/scaled_type_convert.hpp"

using ck::e8m0_bexp_t;
using ck::f8_ocp_t;
using ck::f8x16_ocp_t;
using ck::f8x2_ocp_t;
using ck::f8x32_ocp_t;
using ck::float16_t;
using ck::float2_t;
using ck::float32_t;
using ck::mxf8_convert_rne;
using ck::mxf8_convert_sr;
using ck::scaled_type_convert;
using ck::type_convert;

using ck::fp8_impl::fp8x2_storage_t;

constexpr uint64_t test_size = 256 * 256 + 2 + 4 + 6;
/**
* @brief Tests conversion of FP8 values to float using E8M0 exponent scaling.
*
* This function performs a series of conversions from FP8 values to float values using
* E8M0 exponent scaling. It handles all possible combinations of E8M0 and FP8 values,
* as well as specific vector and rounding conversions.
*
* @param N The maximum number of conversions to perform.
* @param p_test Pointer to the output array where the converted float values will be stored.
* @param p_completed Pointer to a variable that tracks the number of completed conversions.
*
* @note If either p_test or p_completed is nullptr, the function will return immediately.
* @note The function will stop converting if the number of conversions reaches N.
* @note First 256*256 conversions are for all possible combinations of E8M0 and FP8 values that are
* stored in memory sequentially with FP8 values varying faster.
*
* The function performs the following conversions:
* - All possible combinations of E8M0 and FP8 values. [256x256]
* - Vector conversions f8x2 -> f32x2. [2]
* - Vector conversions f32x2 -> f8x2 rne. [2]
* - Vector conversions f32x2 -> f8x2 sr. [2]
* - Round to nearest even conversions for specific float values. [6]
*
* The results are stored in the p_test array, and the number of completed conversions
* is updated in the p_completed variable.
*/
__host__ __device__ void test_mx_fp8_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& i = *p_completed;
    i           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    // All possible combinations of E8M0 and FP8
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
        {
            uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);

            auto v    = scaled_type_convert<float>(e8m0_bexp_t(exp_id), f8_ocp_t{fp8_uid});
            p_test[i] = v;
            i++;
            if(i >= N)
            {
                return;
            }
        }
    }

    /// Test vector conversions

    // f8x2 -> f32x2
    f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
    auto scale2 = e8m0_bexp_t(2.0f);

    float2_t f32x2 = scaled_type_convert<float2_t>(scale2, fp8x2);

    p_test[i++] = f32x2[0];
    if(i >= N)
    {
        return;
    }
    p_test[i++] = f32x2[1];
    if(i >= N)
    {
        return;
    }

    // f32x2 -> f8x2
    f32x2 = {-8.0f, 4.0f};

    fp8x2 = mxf8_convert_rne<f8x2_ocp_t>(f32x2, type_convert<float>(scale2)); // expect {-4, 2}

    p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
    if(i >= N)
    {
        return;
    }

    auto scale4 = e8m0_bexp_t(4.0f);

    fp8x2 = mxf8_convert_sr<f8x2_ocp_t>(f32x2, type_convert<float>(scale4)); // expect {-2, 1}

    p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-2f
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 1f
    if(i >= N)
    {
        return;
    }

    /// Test round to nearest even
    p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(1024.0f, 4.0f)); // 1024/4
    if(i >= N)
    {
        return;
    }
    p_test[i++] = type_convert<float>(
        mxf8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::quiet_NaN(), 4.0f)); // => NaN
    if(i >= N)
    {
        return;
    }
    // Inf/2 > 448 => NaN on device
    p_test[i++] = type_convert<float>(
        mxf8_convert_rne<f8_ocp_t>(std::numeric_limits<float>::infinity(), 2.0f));
    if(i >= N)
    {
        return;
    }
    // 256/0.5 > 448 => NaN on device
    p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(256.0f, 0.5f));
    if(i >= N)
    {
        return;
    }
    // -256/0.5 < -448 => NaN on device
    p_test[i++] = type_convert<float>(mxf8_convert_rne<f8_ocp_t>(-256.0f, 0.5f));
    if(i >= N)
    {
        return;
    }
    // proper scale selection 2^13 < 10000; 2^8 < 448 => scale = 2^(13-8) = 2^5
    p_test[i++] =
        type_convert<float>(mxf8_convert_rne<f8_ocp_t>(10000.0f, 32.0f)); // 10000/32 = 312.5
    if(i >= N)
    {
        return;
    }
}

TEST(MXFP8, HostScaledConvert)
{
    std::vector<float> out(test_size, -1.0f);
    uint64_t completed = 0;

    test_mx_fp8_scaled_convert(test_size, out.data(), &completed);

    // V = X * P; X - E8M0 scale, P - FP8
    // If X = NaN, then V = NaN regardless of P
    uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
    {
        auto idx = e8m0_nan_id * 256 + fp8_id;
        ASSERT_TRUE(std::isnan(out[idx]));
    }

    // If P in {Inf, NaN}, then V = P
    std::set<uint8_t> fp8_nan_ids;
    fp8_nan_ids.insert(0b11111111); //-NaN
    fp8_nan_ids.insert(0b01111111); // +NaN
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;

        for(auto fp8_nan_id : fp8_nan_ids)
        {
            auto idx = exp_id * 256 + fp8_nan_id;
            ASSERT_TRUE(std::isnan(out[idx]));
        }
    }

    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;

        for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
        {
            if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
                continue;

            uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);

            auto idx = exp_id * 256 + fp8_uid;
            ASSERT_FLOAT_EQ(out[idx],
                            type_convert<float>(e8m0_bexp_t(exp_id)) *
                                type_convert<float>(f8_ocp_t{fp8_uid}))
                << "exp_id: " << exp_id << " fp8_id: " << fp8_id << std::endl
                << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                << type_convert<float>(f8_ocp_t{fp8_uid});
        }
    }

    /// Test vector conversions
    auto i = 256 * 256;

    // f8x2 -> f32x2
    EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
    EXPECT_EQ(out[i++], powf(2.0f, -8.0f));

    // f32x2 -> fp8x2
    // RNE
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    // SR
    EXPECT_EQ(out[i++], -2.0f);
    EXPECT_EQ(out[i++], 1.0f);

    /// Test round to nearest even
    EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Lowest()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(type_convert<f8_ocp_t>(312.5f)))
        << "out[i-1]: " << out[i - 1];

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}

__global__ void test_mx_fp8_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    test_mx_fp8_scaled_convert(N, p_test, p_completed);
}

TEST(MXFP8, DeviceScaledConvert)
{
    std::vector<float> out(test_size, -1.0f);
    DeviceMem device_out(test_size * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_fp8_device_scaled_convert<<<1, 1>>>(
        test_size,
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    // V = X * P; X - E8M0 scale, P - FP8
    // If X = NaN, then V = NaN regardless of P
    uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
    {
        auto idx = e8m0_nan_id * 256 + fp8_id;
        ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
    }

    // If P in {Inf, NaN}, then V = P
    std::set<uint8_t> fp8_nan_ids;
    fp8_nan_ids.insert(0b11111111); //-NaN
    fp8_nan_ids.insert(0b01111111); // +NaN
    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;

        for(auto fp8_nan_id : fp8_nan_ids)
        {
            auto idx = exp_id * 256 + fp8_nan_id;
            ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
        }
    }

    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
    {
        if(exp_id == e8m0_nan_id)
            continue;

        for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
        {
            if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
                continue;

            uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);

            auto idx = exp_id * 256 + fp8_uid;
            ASSERT_FLOAT_EQ(out[idx],
                            type_convert<float>(e8m0_bexp_t(exp_id)) *
                                type_convert<float>(f8_ocp_t{fp8_uid}))
                << "exp_id: " << exp_id << " fp8_id: " << fp8_id << std::endl
                << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
                << type_convert<float>(f8_ocp_t{fp8_uid});
        }
    }

    /// Test vector conversions
    auto i = 256 * 256;

    // f8x2 -> f32x2
    EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
    EXPECT_EQ(out[i++], powf(2.0f, -8.0f));

    // f32x2 -> fp8x2
    // RNE
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    // SR
    EXPECT_EQ(out[i++], -2.0f);
    EXPECT_EQ(out[i++], 1.0f);

    /// Test round to nearest even
    EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#if 1
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
#else
    // NOTE: Host and Device have different behavior.
    // Device returns NaN, while Host returns Max (saturation to finite value).
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
        << "out[i-1]: " << out[i - 1];
    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Lowest()))
        << "out[i-1]: " << out[i - 1];
#endif
    EXPECT_EQ(out[i++], type_convert<float>(type_convert<f8_ocp_t>(312.5f)))
        << "out[i-1]: " << out[i - 1];

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
__host__ __device__ float vec16_generator(ck::index_t i)
{
    return (i < 8 ? -1.0 : 1.0) * powf(2.0f, i % 8);
}

__global__ void test_mx_fp8x16_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int N = 16;

    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& i = *p_completed;
    i           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    auto scale2 = e8m0_bexp_t(2.0f);
    f8x16_ocp_t fp8x16{};
    float16_t float16{};

    ck::static_for<0, N, 1>{}(
        [&](auto ii) { float16[static_cast<int>(ii)] = vec16_generator(ii); });

    fp8x16 = scaled_type_convert<ck::f8x16_ocp_t>(scale2, float16);

    ck::static_for<0, N, 1>{}([&](auto ii) {
        p_test[i++] = type_convert<float>(fp8x16.AsType<f8_ocp_t>()(ck::Number<ii>{}));
    });
}

TEST(MXFP8, DeviceF32x16ToF8x16ScaledConvert)
{
    constexpr int N = 16;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_fp8x16_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec16_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}

__host__ __device__ float vec32_generator(ck::index_t i)
{
    if(i < 16)
    {
        return vec16_generator(i % 16);
    }
    else
    {
        return 1.5f * vec16_generator(i % 16);
    }
}

__global__ void test_mx_fp8x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int N = 32;

    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& i = *p_completed;
    i           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    auto scale2 = e8m0_bexp_t(2.0f);
    f8x32_ocp_t fp8x32{};
    float32_t float32{};

    ck::static_for<0, N, 1>{}(
        [&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });

    fp8x32 = mxf8_convert_rne<f8x32_ocp_t>(float32, type_convert<float>(scale2));

    ck::static_for<0, N, 1>{}(
        [&](auto ii) { p_test[i++] = type_convert<float>(fp8x32.AsType<f8_ocp_t>()(ii)); });
}

TEST(MXFP8, DeviceF32x32ToF8x32ScaledConvert)
{
    constexpr int N = 32;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_fp8x32_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}

__global__ void test_mx_fp8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
{
    constexpr int N = 32;

    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& i = *p_completed;
    i           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    auto scale2 = e8m0_bexp_t(8.0f);
    f8x32_ocp_t fp8x32{};
    float32_t float32{};

    ck::static_for<0, N, 1>{}(
        [&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });

    fp8x32 = mxf8_convert_sr<f8x32_ocp_t>(float32, type_convert<float>(scale2));

    ck::static_for<0, N, 1>{}(
        [&](auto ii) { p_test[i++] = type_convert<float>(fp8x32.AsType<f8_ocp_t>()(ii)); });
}

TEST(MXFP8, DeviceF32x32ToF8x32ScaledConvertSR)
{
    constexpr int N = 32;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_fp8x32_device_scaled_convert_sr<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}

__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int N = 32;

    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& i = *p_completed;
    i           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    auto scale2 = e8m0_bexp_t(4.0f);
    f8x32_ocp_t fp8x32{};
    float32_t float32{};

    ck::static_for<0, N, 1>{}([&](auto ii) {
        fp8x32.AsType<f8_ocp_t>()(ii) = type_convert<f8_ocp_t>(vec32_generator(ii) / 16.0f);
    });

    float32 = scaled_type_convert<float32_t>(scale2, fp8x32);

    ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32[static_cast<int>(ii)]; });
}

TEST(MXFP8, DeviceF8x32ToF32x32ScaledConvert)
{
    constexpr int N = 32;
    std::vector<float> out(N, -1.0f);
    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    auto i = 0;
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
    });

    EXPECT_EQ(N, completed);
    EXPECT_EQ(N, i);
}
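The #if 1 branches above pin down a deliberate host/device split on overflow: OCP e4m3 has no infinities, so the device conversion yields NaN whenever |x / scale| exceeds the finite maximum of 448, while the host path saturates to NumericLimits<f8_ocp_t>::Max()/Lowest(). A sketch of the host-style saturating behaviour, for illustration only (the 448 bound is the e4m3 finite maximum; this is not CK's conversion routine):

#include <algorithm>
#include <cmath>

// Host-style e4m3 overflow handling: NaN passes through, finite values saturate to +/-448.
float saturate_to_e4m3_range(float x)
{
    if(std::isnan(x))
        return x;
    return std::clamp(x, -448.0f, 448.0f); // C++17
}

Under this model, 256.0f / 0.5f = 512 clamps to 448 on the host, which is exactly the Max() expectation asserted in the disabled #else branch.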
test/data_type/test_pk_i4.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include <bitset>
#include <cinttypes>
#include <cstdint>
#include <iomanip>

#include "gtest/gtest.h"
#include <hip/hip_runtime.h>

#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

using ck::bhalf2_t;
using ck::bhalf_t;
using ck::float2_t;
using ck::half2_t;
using ck::half4_t;
using ck::half_t;
using ck::pk_i4_t;
using ck::pk_i4x4_t;

TEST(PackedInt4, ConvertToFloat)
{
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    constexpr float first_input_val  = 7.f;
    constexpr float second_input_val = -1.f;
#else
    constexpr float first_input_val  = -1.f;
    constexpr float second_input_val = 7.f;
#endif

    uint8_t data = 0b11110111; // {-1, 7}
    pk_i4_t in   = ck::bit_cast<int8_t>(data);

    float2_t out = ck::type_convert<float2_t>(in);

    EXPECT_EQ(out.x, first_input_val);
    EXPECT_EQ(out.y, second_input_val);
}

TEST(PackedInt4, ConvertToHalf)
{
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    constexpr half_t first_input_val  = ck::type_convert<half_t>(7.f);
    constexpr half_t second_input_val = ck::type_convert<half_t>(-1.f);
#else
    constexpr half_t first_input_val  = ck::type_convert<half_t>(-1.f);
    constexpr half_t second_input_val = ck::type_convert<half_t>(7.f);
#endif

    uint8_t data = 0b11110111; // {-1, 7}
    pk_i4_t in   = ck::bit_cast<int8_t>(data);

    half2_t out = ck::type_convert<half2_t>(in);

    EXPECT_EQ(out.x, first_input_val);
    EXPECT_EQ(out.y, second_input_val);
}

TEST(PackedInt4, ConvertToBHalf)
{
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    const bhalf_t first_input_val  = ck::type_convert<bhalf_t>(7.f);
    const bhalf_t second_input_val = ck::type_convert<bhalf_t>(-1.f);
#else
    const bhalf_t first_input_val  = ck::type_convert<bhalf_t>(-1.f);
    const bhalf_t second_input_val = ck::type_convert<bhalf_t>(7.f);
#endif

    uint8_t data = 0b11110111; // {-1, 7}
    pk_i4_t in   = ck::bit_cast<int8_t>(data);

    bhalf2_t out = ck::type_convert<bhalf2_t>(in);

    EXPECT_EQ(out.x, first_input_val);
    EXPECT_EQ(out.y, second_input_val);
}
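These tests boil down to nibble order in the packed byte: 0b11110111 carries 0b0111 (+7) in the low nibble and 0b1111 (-1 in two's complement) in the high nibble, and CK_USE_PK4_LAYOUT_SHUFFLE decides which one converts out first. A standalone sketch of a sign-extending unpack, taking low-nibble-first as an assumption for illustration (not CK's exact layout rule):

#include <cstdint>
#include <utility>

// Sign-extend a 4-bit two's-complement nibble to int.
int sext4(uint8_t nibble) { return (nibble & 0x8) ? static_cast<int>(nibble) - 16 : nibble; }

// Unpack one byte into two signed int4 values, low nibble first.
std::pair<int, int> unpack_pk_i4(uint8_t data)
{
    return {sext4(data & 0xF), sext4(data >> 4)};
}

// unpack_pk_i4(0b11110111) == {7, -1}, matching the shuffled ordering in the tests above.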
test/mx_mfma_op/CMakeLists.txt (new file, mode 100644)
add_custom_target(test_mx_mfma)

add_gtest_executable(test_mx_mfma_op mx_mfma_op.cpp)
if(result EQUAL 0)
    target_link_libraries(test_mx_mfma_op PRIVATE utility)
endif()
add_dependencies(test_mx_mfma test_mx_mfma_op)
test/mx_mfma_op/mx_mfma_op.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"

#include "mx_mfma_op.hpp"

using ck::e8m0_bexp_t;
using ck::f8_t;
using ck::half_t;
using ck::type_convert;
/**
* @brief Run the test for the given MFMA instruction
*
* @param init - selects initialization algorithm for A and B tensors
*/
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mfma_test(ck::index_t init)
{
    using ALayout = ck::tensor_layout::gemm::ColumnMajor;
    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
    using CLayout = ck::tensor_layout::gemm::ColumnMajor;

    using AccType    = float; // only MFMA_F32 instructions supported
    using CPUAccType = AccType;

    ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;

    constexpr auto BLOCK_M = mfma_instr.m_per_blk;
    constexpr auto BLOCK_N = mfma_instr.n_per_blk;
    constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;

    const auto mx_mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;

    bool pass = true;

    pass = ck::mfma_test::TestMFMA<decltype(mx_mfma_kernel),
                                   AType,
                                   BType,
                                   CType,
                                   AccType,
                                   CPUAccType,
                                   ALayout,
                                   BLayout,
                                   CLayout,
                                   BLOCK_M,
                                   BLOCK_N,
                                   BLOCK_K>{}(mx_mfma_kernel, init);

    return pass;
}

TEST(MFMA, FP8MFMA16x16x128)
{
    auto AB_init = 0;
    auto pass    = run_mfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
    EXPECT_TRUE(pass);
}

TEST(MFMA, FP8MFMA32x32x64)
{
    auto AB_init = 0;
    auto pass    = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
    EXPECT_TRUE(pass);
}
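Both instructions run on a single 64-lane wavefront, so the per-lane fragment sizes the kernel in mx_mfma_op.hpp builds with vector_type<T, BLOCK_M * BLOCK_K / WAVE_SIZE> (and the analogous B and C types) follow directly from the block shapes in the instruction names. A small compile-time check of those sizes, derived from the shapes above:

#include <cstdint>

// Elements each of 64 lanes holds for a rows x cols operand tile.
constexpr int32_t elems_per_lane(int32_t rows, int32_t cols, int32_t wave_size = 64)
{
    return rows * cols / wave_size;
}

// 16x16x128: A (16x128) and B (128x16) -> 32 f8 values per lane; C (16x16) -> 4 per lane.
static_assert(elems_per_lane(16, 128) == 32, "A/B fragment for 16x16x128");
static_assert(elems_per_lane(16, 16) == 4, "C fragment for 16x16x128");
// 32x32x64: A (32x64) and B (64x32) -> 32 per lane; C (32x32) -> 16 per lane.
static_assert(elems_per_lane(32, 64) == 32, "A/B fragment for 32x32x64");
static_assert(elems_per_lane(32, 32) == 16, "C fragment for 32x32x64");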
test/mx_mfma_op/mx_mfma_op.hpp (new file, mode 100644)
#pragma once

#include "ck/ck.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"

namespace ck {

// MFMA instructions supported in this test
enum class MFMA_F8F6F4
{
    F32_16x16x128 = static_cast<int>(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4
    F32_32x32x64  = static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4)   // V_MFMA_F32_32X32X64_F8F6F4
};
template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>
struct mfma_type_selector;

template <typename AFragT, typename BFragT, typename AccumFragT>
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
{
    __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
    {
        auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
        op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
    }
};

template <typename AFragT, typename BFragT, typename AccumFragT>
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
{
    __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
    {
        auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
        op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
    }
};

template <typename VecT>
static constexpr int32_t vectorSize(const VecT&)
{
    return scalar_type<VecT>::vector_size;
}
// Define a load function for input A blocks:
// Size: (BLOCK_M x BLOCK_K)
// ASSUMPTION:
// - We want contiguous BLOCK_M sized column neighbors in register.
// - Data is in col_major format
// This means:
// - From A we will load K columns of size BLOCK_M to satisfy our input data
template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
__device__ AFragT load_A_col_major(AType const* input_ptr)
{
// clang-format off
// Register Mapping for 16x128: || Register Mapping for 32x64:
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M |
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element
// Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0]
// Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1]
// Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2]
// Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3]
// Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4]
// Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5]
// Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6]
// Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7]
// Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8]
// Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9]
// Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10]
// Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11]
// Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12]
// Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13]
// Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14]
// Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15]
// Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16]
// Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17]
// Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18]
// Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19]
// Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20]
// Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21]
// Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22]
// Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23]
// Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24]
// Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25]
// Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26]
// Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27]
// Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28]
// Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29]
// Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30]
// Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31]
// clang-format on
// Here we want to load a BLOCK_M x BLOCK_K block of data.
    static constexpr uint32_t VW = vectorSize(AFragT{});

    using ARawT        = typename scalar_type<AFragT>::type;
    using AScalarFragT = vector_type<ARawT, VW>::type;

    // To start the loading process, let's visualize in 2D coords.
    // Each thread will load 32 elements.
    // We need to know where they start, and where the next elements are.
    auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M,         // Row
                                       (threadIdx.x / BLOCK_M) * VW); // Col
    auto stepCoord2D  = std::make_pair(0u, 1u);

    // Flatten to 1D col_major offsets.
    auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };

    // BLOCK_M is a stride in A matrix
    auto startOffset = col_major(startCoord2D, BLOCK_M);
    auto kOffset     = col_major(stepCoord2D, BLOCK_M); // kOffset == BLOCK_M
    // This means every BLOCK_M element is loaded into output vector

    auto fragA = AScalarFragT{};
#pragma unroll VW
    for(uint32_t i = 0; i < VW; i++)
    {
        fragA[i] = bit_cast<ARawT>(input_ptr[startOffset + i * kOffset]);
    }
    return fragA;
}
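// Worked example of the offset math above (illustrative only, for the 16x16x128 shape):
// VW = 16 * 128 / 64 = 32, so lane 17 starts at row 17 % 16 = 1, col (17 / 16) * 32 = 32;
// startOffset = 1 + 32 * 16 = 513 and kOffset = 16, so it reads elements 513, 529, ... --
// row 1, columns K32..K63, consistent with the register-mapping table.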
// Define a load function for input B blocks:
// Size: (BLOCK_K x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_N sized row neighbors in register.
// - Data is in row_major format
// This means:
// - From B we will load K rows of size BLOCK_N to satisfy our input data
template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
__device__ BFragT load_B_col_major(BType const* input_ptr)
{
// clang-format off
// Register Mapping for 128x16: || Register Mapping for 64x32:
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || N | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element
// Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0]
// Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1]
// Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2]
// Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3]
// Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4]
// Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5]
// Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6]
// Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7]
// Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8]
// Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9]
// Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10]
// Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11]
// Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12]
// Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13]
// Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14]
// Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15]
// Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16]
// Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17]
// Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18]
// Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19]
// Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20]
// Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21]
// Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22]
// Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23]
// Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24]
// Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25]
// Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26]
// Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27]
// Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28]
// Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29]
// Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30]
// Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31]
// clang-format on
// Here we want to load a BLOCK_K x BLOCK_N block of data.
    static constexpr uint32_t VW = vectorSize(BFragT{});

    // To start the loading process, let's visualize in 2D coords.
    // Each thread will load 32 elements.
    // We need to know where they start, and where the next elements are.
    auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row
                                       threadIdx.x % BLOCK_N);       // Col

    // Flatten to 1D col_major offsets.
    auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };

    auto startOffset = col_major(startCoord2D, BLOCK_K);

    auto const* fragPtr = reinterpret_cast<BFragT const*>(input_ptr + startOffset);
    return *fragPtr;
}
// Define a store function for C
// Size: (BLOCK_M x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_N sized row neighbors in register.
// - Data is in col_major format
// This means:
// - From C we will load BLOCK_M rows of size BLOCK_N to satisfy our input data
template <typename CType, typename CFragT, int32_t BLOCK_M, int32_t BLOCK_N>
struct store_C_col_major;
// Here we want to store a 16x16 block of data.
//
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector
// Register Element ------------ ------------- ------------ -------------- Element
// Reg0 | M0 | M4 | M8 | M12 | v[0]
// Reg1 | M1 | M5 | M9 | M13 | v[1]
// Reg2 | M2 | M6 | M10 | M14 | v[2]
// Reg3 | M3 | M7 | M11 | M15 | v[3]
template <typename CType, typename CFragT>
struct store_C_col_major<CType, CFragT, 16, 16>
{
    __device__ void operator()(CType* output, CFragT cFrag)
    {
        static constexpr uint32_t VW  = vectorSize(cFrag); // 4
        static constexpr uint32_t Dim = 16;

        // Each thread will load 4 elements.
        // We need to know where they start, and where the next elements are.
        auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
                                           threadIdx.x % Dim);       // Col

        // Flatten to 1D col_major offsets.
        auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };

        auto startOffset = col_major(startCoord2D, 16);

        auto* fragPtr = reinterpret_cast<CFragT*>(output + startOffset);
        *fragPtr      = cFrag;
    }
};
// Here we want to store a 32x32 block of data.
// Register Mapping:
// Size | BLOCK_N | BLOCK_N |
// N | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- Element
// Reg0 | M0 | M4 | v[0]
// Reg1 | M1 | M5 | v[1]
// Reg2 | M2 | M6 | v[2]
// Reg3 | M3 | M7 | v[3]
// ____________ _____________
// Reg4 | M8 | M12 | v[4]
// Reg5 | M9 | M13 | v[5]
// Reg6 | M10 | M14 | v[6]
// Reg7 | M11 | M15 | v[7]
// ____________ _____________
// Reg8 | M16 | M20 | v[8]
// Reg9 | M17 | M21 | v[9]
// Reg10 | M18 | M22 | v[10]
// Reg11 | M19 | M23 | v[11]
// ____________ _____________
// Reg12 | M24 | M28 | v[12]
// Reg13 | M25 | M29 | v[13]
// Reg14 | M26 | M30 | v[14]
// Reg15 | M27 | M31 | v[15]
template
<
typename
CType
,
typename
CFragT
>
struct
store_C_col_major
<
CType
,
CFragT
,
32
,
32
>
{
__device__
void
operator
()(
CType
*
output
,
CFragT
cFrag
)
{
static
constexpr
uint32_t
WAVE_SIZE
=
64
;
static
constexpr
uint32_t
VW
=
4
;
static
constexpr
uint32_t
Dim
=
32
;
static
constexpr
uint32_t
M_PER_VW_CHUNK
=
VW
*
WAVE_SIZE
/
32
;
// 8
auto
startCoord2D
=
std
::
make_pair
((
threadIdx
.
x
/
Dim
)
*
VW
,
// Row
threadIdx
.
x
%
Dim
);
// Col
// Major step between 'chunks'
auto
majorStepCoord2D
=
std
::
make_pair
(
M_PER_VW_CHUNK
,
0
);
        // Flatten to 1D col_major offsets.
        auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };

        auto startOffset  = col_major(startCoord2D, 32);
        auto kMajorOffset = col_major(majorStepCoord2D, 32); // 8

        // We can vector store 4 contiguous elements at a time.
        using CRawT        = typename scalar_type<CFragT>::type;
        using CScalarFragT = vector_type<CRawT, VW>::type;

        union
        {
            CFragT       frag;
            CScalarFragT chunks[vectorSize(CFragT{}) / VW];
        } fragC{cFrag}; // Initialize with input fragment
        *(reinterpret_cast<CScalarFragT*>(output + startOffset))                    = fragC.chunks[0];
        *(reinterpret_cast<CScalarFragT*>(output + startOffset + kMajorOffset))     = fragC.chunks[1];
        *(reinterpret_cast<CScalarFragT*>(output + startOffset + 2 * kMajorOffset)) = fragC.chunks[2];
        *(reinterpret_cast<CScalarFragT*>(output + startOffset + 3 * kMajorOffset)) = fragC.chunks[3];
    }
};
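// Illustrative compile-time check (not part of the original test) of the
// chunked 32x32 mapping above: each lane owns four 4-element chunks spaced
// M_PER_VW_CHUNK = 4 * 64 / 32 = 8 rows apart. The helper name is hypothetical.
constexpr uint32_t chunk_start_row_32x32(uint32_t lane, uint32_t chunk)
{
    return (lane / 32u) * 4u + chunk * 8u;
}
static_assert(chunk_start_row_32x32(32u, 1u) == 12u,
              "lane 32, chunk 1 (v[4..7]) starts at row M12");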
template <typename AType,
          typename BType,
          typename CType,
          typename AccType,
          int32_t BLOCK_M,
          int32_t BLOCK_N,
          int32_t BLOCK_K>
__global__ void matmul(const AType* a, const BType* b, CType* c)
{
    constexpr int WAVE_SIZE = 64;
    assert(threadIdx.x < WAVE_SIZE);
    // Single-wave kernel, launched as <<<1, 64>>>: exactly one block in the grid.
    assert(gridDim.x == 1 && gridDim.y == 1 && gridDim.z == 1);
    using AFragT        = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
    using BFragT        = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
    using CFragT        = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
    using AccumFragT    = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
    using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;

    // Create frags
    auto fragA   = AFragT{};
    auto fragB   = BFragT{};
    auto fragC   = CFragT{};
    auto fragAcc = AccumFragT{0};
    // Load the inputs.
    // A = col major, BLOCK_M x BLOCK_K
    fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
    // B = col major, BLOCK_K x BLOCK_N
    fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);

    // Matrix multiply-accumulate using MFMA units
    // Accumulation intermediate = BLOCK_M x BLOCK_N
    mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);

    for(int i = 0; i < vectorSize(fragC); ++i)
    {
        fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
    }

    auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
    storeC(c, fragC);
}
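// Illustrative (hypothetical) instantiation of the kernel above for a
// single-wave 16x16x128 bf8 GEMM; the tests drive it through RunDeviceGEMM
// below rather than launching it directly like this:
//   matmul<bf8_ocp_t, bf8_ocp_t, float, float, 16, 16, 128><<<1, 64>>>(a, b, c);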
/**
* @brief Structure to hold dimension parameters for GEMM tensors.
*
* M Number of rows in matrix A and matrix C.
* N Number of columns in matrix B and matrix C.
* K Number of columns in matrix A and number of rows in matrix B.
* StrideA Stride (leading dimension) of matrix A.
* StrideB Stride (leading dimension) of matrix B.
* StrideC Stride (leading dimension) of matrix C.
*/
struct GemmParams
{
    ck::index_t M = 16;
    ck::index_t N = 16;
    ck::index_t K = 128;

    ck::index_t StrideA = -1;
    ck::index_t StrideB = -1;
    ck::index_t StrideC = -1;
};
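// Note: a stride of -1 marks "use the packed default"; it is resolved later by
// f_get_default_stride in TestMFMA below. For example, a col-major 16x128 A
// matrix resolves to StrideA = 16.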
namespace mfma_test {
template <typename GemmInstance,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void RunHostGEMM(const Tensor<ADataType>& A,
                 const Tensor<BDataType>& B,
                 Tensor<CDataType>& C,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
{
    auto ref_gemm     = GemmInstance{};
    auto ref_invoker  = ref_gemm.MakeInvoker();
    auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);

    ref_invoker.Run(ref_argument);
}
template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
bool RunDeviceGEMM(KernelType kernel,
                   const Tensor<ADataType>& A,
                   const Tensor<BDataType>& B,
                   Tensor<CDataType>& C)
{
    DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
    DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());

    a_m_k_device_buf.ToDevice(A.mData.data());
    b_n_k_device_buf.ToDevice(B.mData.data());

    kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                      static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));

    c_m_n_device_buf.FromDevice(C.mData.data());

    return true;
}
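// Note: RunDeviceGEMM launches the kernel as <<<1, 64>>>, i.e. a single
// workgroup holding exactly one 64-lane wavefront, which is what the fragment
// mappings in this file assume.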
template <typename DeviceMFMA,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GPUAccDataType,
          typename CPUAccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          index_t BLOCK_M,
          index_t BLOCK_N,
          index_t BLOCK_K>
struct TestMFMA
{
    auto PrepareGemmTensors(const GemmParams& params, index_t init)
    {
        auto f_host_tensor_descriptor =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({stride, 1}));
                }
                else
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({1, stride}));
                }
            };
        Tensor<ADataType> a_m_k(f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<BDataType> b_n_k(f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<CDataType> c_m_n_host_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<CDataType> c_m_n_device_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        switch(init)
        {
        case 0:
            a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
            // NOTE: not all numbers are representable in FP8, BF8, etc.
            b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
            break;
        case 1:
            // results in C = {K}
            a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
            b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
            break;
        case 2:
            // expect small round off errors
            a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});
            b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-5, 5});
            break;
        case 3:
            // expect small round off errors
            a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
            b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
            break;
        default:
            // all initial values are representable in FP8, BF8
            a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
            b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
            break;
        }
        return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
    }
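    // Illustrative usage (hypothetical): PrepareGemmTensors(params, /*init=*/1)
    // fills both inputs with 1.0f, so every element of C should equal exactly K.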
    auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
    {
        std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
                  << ", CLayout = " << CLayout{}.name << std::endl;

        // Arrange
        GemmParams params;
        params.M = BLOCK_M;
        params.N = BLOCK_N;
        params.K = BLOCK_K;
        auto f_get_default_stride =
            [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
                if(stride == -1)
                {
                    // a stride of -1 means "packed": derive the default from the layout
                    if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
                    {
                        return static_cast<std::size_t>(col);
                    }
                    else
                    {
                        return static_cast<std::size_t>(row);
                    }
                }
                else
                    return static_cast<std::size_t>(stride);
            };
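        // Illustrative note: e.g. f_get_default_stride(16, 128, -1, ColumnMajor{})
        // yields 16 (packed col-major), while a row-major layout would yield 128.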
        params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{});
        params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
        params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
        auto host_tensors = PrepareGemmTensors(params, init);

        const Tensor<ADataType>& a  = std::get<0>(host_tensors);
        const Tensor<BDataType>& b  = std::get<1>(host_tensors);
        Tensor<CDataType>& c_host   = std::get<2>(host_tensors);
        Tensor<CDataType>& c_device = std::get<3>(host_tensors);

        using PassThrough = ck::tensor_operation::element_wise::PassThrough;
        auto a_element_op = PassThrough{};
        auto b_element_op = PassThrough{};
        auto c_element_op = PassThrough{};
        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                                BDataType,
                                                                                CDataType,
                                                                                CPUAccDataType,
                                                                                PassThrough,
                                                                                PassThrough,
                                                                                PassThrough>;

        RunHostGEMM<ReferenceGemmInstance>(a, b, c_host, a_element_op, b_element_op, c_element_op);
        RunDeviceGEMM(mfma_kernel, a, b, c_device);
        bool res = false;
        if constexpr(std::is_same<CDataType, float>::value ||
                     std::is_same<CDataType, half_t>::value)
        {
            res = ck::utils::check_err(c_device.mData, c_host.mData);
            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
        }
        else
        {
            std::cout << "UNSUPPORTED CDataType" << std::endl;
        }

        return res;
    }
};
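// Illustrative (hypothetical) use of the fixture above; the real tests bind
// these parameters through GTest fixtures:
//   using Col = ck::tensor_layout::gemm::ColumnMajor;
//   auto test = TestMFMA<decltype(kernel), bf8_ocp_t, bf8_ocp_t, float,
//                        float, float, Col, Col, Col, 16, 16, 128>{};
//   bool ok   = test(kernel, /*init=*/1);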
} // namespace mfma_test
} // namespace ck
test/smfmac_op/smfmac_op_xdl.cpp
View file @
dbb7002d
...
...
@@ -40,7 +40,7 @@ class TestSmfmac : public ::testing::Test
 void Run()
 {
     bool pass = true;
-    if(ck::get_device_name() == "gfx942")
+    if(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")
     {
         constexpr auto matmul_default = ck::smfmac_op_util::matmul<Src1Type, Src1VecSize,
...
...