Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e38b4a33
"...composable_kernel_rocm.git" did not exist on "313bbea5886850acab286f45e9d9816cf0b0dca0"
Commit
e38b4a33
authored
Feb 14, 2025
by
Rostyslav Geyyer
Browse files
Add a conversion for a repro test
parent
d1499dd8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
0 deletions
+62
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+27
-0
test/data_type/test_mx_fp4_repro.cpp
test/data_type/test_mx_fp4_repro.cpp
+35
-0
No files found.
include/ck/utility/type_convert.hpp
View file @
e38b4a33
...
@@ -978,6 +978,33 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
...
@@ -978,6 +978,33 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
#endif
#endif
}
}
// convert vector of 2 fp32 to vector of 2 fp4 with sr (repro variant)
//
// x     - pair of fp32 inputs
// scale - scale forwarded to the gfx950 builtin; the fallback path divides each input by it
//
// On gfx950 the builtin receives the lanes in swapped order ({x[1], x[0]}).
// On every other target the converted x[0] is placed in bits 7..4 and the
// converted x[1] in bits 3..0 of the result.
inline __host__ __device__ f4x2_t f4_convert_sr_repro(float2_t x, float scale = 1.0f)
{
    // Random bits are derived from the address of the argument and its first lane.
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);

    // 32-bit scratch word; only the first f4x2_t overlaying it is handed back.
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } packed{0};

#if defined(__gfx950__)
    packed.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
        packed.bitwise, float2_t{x[1], x[0]}, rng, scale, 0);
#else
    // Both lanes use the same rng value.
    uint8_t low_nibble  = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
    uint8_t high_nibble = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
    packed.bitwise      = (high_nibble << 4) | low_nibble;
#endif

    return packed.f4x2_array[0];
}
// convert vector of 32 fp32 to vector of 32 fp4 with sr
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline
__host__
__device__
f4x32_t
f4_convert_sr
(
float32_t
x
,
float
scale
=
1.0
f
)
inline
__host__
__device__
f4x32_t
f4_convert_sr
(
float32_t
x
,
float
scale
=
1.0
f
)
{
{
...
...
test/data_type/test_mx_fp4_repro.cpp
View file @
e38b4a33
...
@@ -63,6 +63,23 @@ __host__ __device__ void test_mx_fp32_to_fp4_sr(float* p_test)
...
@@ -63,6 +63,23 @@ __host__ __device__ void test_mx_fp32_to_fp4_sr(float* p_test)
// Kernel entry point: forwards the output pointer to test_mx_fp32_to_fp4_sr.
__global__ void run_test_mx_fp32_to_fp4_sr(float* p_test)
{
    test_mx_fp32_to_fp4_sr(p_test);
}
// Kernel entry point: forwards the output pointer to test_mx_fp32_to_fp4_sr.
__global__ void run_test_mx_fp32_to_fp4_sr(float* p_test)
{
    test_mx_fp32_to_fp4_sr(p_test);
}
// Converts the fp32 pair {1.0f, -4.0f} to packed fp4 through ck::f4_convert_sr_repro,
// using an e8m0 scale constructed from 2.0f, and stores the two unpacked elements
// (converted back to float) in p_test[0] and p_test[1].
//
// p_test - output buffer; elements 0 and 1 are written
__host__ __device__ void test_mx_fp32_to_fp4_sr_failing(float* p_test)
{
    float2_t input = {1.0f, -4.0f};
    auto scale     = e8m0_bexp_t(2.0f);

    // expect {0.5, -2}
    f4x2_t converted = ck::f4_convert_sr_repro(input, type_convert<float>(scale));

    // Fetch the single packed element once, then unpack both fp4 values from it.
    auto packed = converted.AsType<f4x2_pk_t>()(ck::Number<0>{});

    p_test[0] = type_convert<float>(f4_t(packed.unpack<>(ck::Number<0>{}))); // 0.5f
    p_test[1] = type_convert<float>(f4_t(packed.unpack<>(ck::Number<1>{}))); // -2.0f
}
// Kernel entry point: forwards the output pointer to test_mx_fp32_to_fp4_sr_failing.
__global__ void run_test_mx_fp32_to_fp4_sr_failing(float* p_test)
{
    test_mx_fp32_to_fp4_sr_failing(p_test);
}
TEST
(
MXFP4
,
FP4ToFP32
)
TEST
(
MXFP4
,
FP4ToFP32
)
{
{
std
::
vector
<
float
>
out
(
2
,
-
1.0
f
);
std
::
vector
<
float
>
out
(
2
,
-
1.0
f
);
...
@@ -120,3 +137,21 @@ TEST(MXFP4, FP32ToFP4SR)
...
@@ -120,3 +137,21 @@ TEST(MXFP4, FP32ToFP4SR)
EXPECT_EQ
(
out
[
0
],
0.5
f
);
EXPECT_EQ
(
out
[
0
],
0.5
f
);
EXPECT_EQ
(
out
[
1
],
-
2.0
f
);
EXPECT_EQ
(
out
[
1
],
-
2.0
f
);
}
}
// Launches the repro kernel with a single thread and compares the two floats it
// writes against the expected values 0.5f and -2.0f.
TEST(MXFP4, FP32ToFP4SRFailing)
{
    // Host results start at -1.0f, which differs from both expected values.
    std::vector<float> out(2, -1.0f);

    // Device buffer with room for every host element.
    DeviceMem device_out(out.size() * sizeof(float));
    float* p_device_out = static_cast<float*>(device_out.GetDeviceBuffer());

    run_test_mx_fp32_to_fp4_sr_failing<<<1, 1>>>(p_device_out);

    // Copy the kernel output into the host vector.
    device_out.FromDevice(out.data());

    // SR
    EXPECT_EQ(out[0], 0.5f);
    EXPECT_EQ(out[1], -2.0f);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment