Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f9181773
Commit
f9181773
authored
Feb 12, 2025
by
Rostyslav Geyyer
Browse files
Permute packed f4_t values
parent
ee8937a8
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
41 deletions
+18
-41
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+3
-3
include/ck/utility/scaled_type_convert.hpp
include/ck/utility/scaled_type_convert.hpp
+2
-2
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+4
-4
test/data_type/test_mx_fp4.cpp
test/data_type/test_mx_fp4.cpp
+9
-32
No files found.
include/ck/utility/data_type.hpp
View file @
f9181773
...
@@ -40,14 +40,14 @@ struct f4x2_pk_t
...
@@ -40,14 +40,14 @@ struct f4x2_pk_t
{
{
static_assert
(
I
<
2
,
"Index is out of range."
);
static_assert
(
I
<
2
,
"Index is out of range."
);
if
constexpr
(
I
==
0
)
if
constexpr
(
I
==
0
)
return
data
&
0b00001111
;
else
return
(
data
>>
4
);
return
(
data
>>
4
);
else
return
data
&
0b00001111
;
}
}
__host__
__device__
inline
type
pack
(
const
type
x0
,
const
type
x1
)
__host__
__device__
inline
type
pack
(
const
type
x0
,
const
type
x1
)
{
{
return
(
x
1
<<
4
)
|
(
x
0
&
0b00001111
);
return
(
x
0
<<
4
)
|
(
x
1
&
0b00001111
);
}
}
};
};
...
...
include/ck/utility/scaled_type_convert.hpp
View file @
f9181773
...
@@ -380,9 +380,9 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
...
@@ -380,9 +380,9 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
#else
#else
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{})),
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{})),
utils
::
to_float
<
f4_t
>
(
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}))};
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}))};
return
ret
;
return
ret
;
#endif
#endif
}
}
...
...
include/ck/utility/type_convert.hpp
View file @
f9181773
...
@@ -742,8 +742,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
...
@@ -742,8 +742,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
uint32_t
bitwise
;
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
f4x2_t
f4x2_array
[
4
];
}
value
{
0
};
}
value
{
0
};
uint8_t
l
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
0
]
/
scale
);
uint8_t
l
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
1
]
/
scale
);
uint8_t
h
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
1
]
/
scale
);
uint8_t
h
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
0
]
/
scale
);
value
.
bitwise
=
(
h
<<
4
)
|
l
;
value
.
bitwise
=
(
h
<<
4
)
|
l
;
return
value
.
f4x2_array
[
0
];
return
value
.
f4x2_array
[
0
];
#endif
#endif
...
@@ -969,8 +969,8 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
...
@@ -969,8 +969,8 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
uint32_t
bitwise
;
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
f4x2_t
f4x2_array
[
4
];
}
value
{
0
};
}
value
{
0
};
uint8_t
l
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
0
]
/
scale
,
rng
);
uint8_t
l
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
1
]
/
scale
,
rng
);
uint8_t
h
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
1
]
/
scale
,
rng
);
uint8_t
h
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
0
]
/
scale
,
rng
);
value
.
bitwise
=
(
h
<<
4
)
|
l
;
value
.
bitwise
=
(
h
<<
4
)
|
l
;
return
value
.
f4x2_array
[
0
];
return
value
.
f4x2_array
[
0
];
#endif
#endif
...
...
test/data_type/test_mx_fp4.cpp
View file @
f9181773
...
@@ -264,11 +264,6 @@ TEST(MXFP4, DeviceScaledConvert)
...
@@ -264,11 +264,6 @@ TEST(MXFP4, DeviceScaledConvert)
device_completed
.
FromDevice
(
&
completed
);
device_completed
.
FromDevice
(
&
completed
);
device_out
.
FromDevice
(
out
.
data
());
device_out
.
FromDevice
(
out
.
data
());
for
(
ck
::
index_t
id
=
0
;
id
<
256
*
16
;
id
++
)
{
printf
(
"%f
\n
"
,
out
.
data
()[
id
]);
}
// V = X * P; X - E8M0 scale, P - FP4
// V = X * P; X - E8M0 scale, P - FP4
// If X = NaN, then V = NaN regardless of P
// If X = NaN, then V = NaN regardless of P
...
@@ -279,32 +274,14 @@ TEST(MXFP4, DeviceScaledConvert)
...
@@ -279,32 +274,14 @@ TEST(MXFP4, DeviceScaledConvert)
ASSERT_TRUE
(
std
::
isnan
(
out
[
idx
]))
<<
"idx: "
<<
idx
<<
" out[idx]: "
<<
out
[
idx
];
ASSERT_TRUE
(
std
::
isnan
(
out
[
idx
]))
<<
"idx: "
<<
idx
<<
" out[idx]: "
<<
out
[
idx
];
}
}
// If P in {Inf, NaN}, then V = P
std
::
set
<
uint8_t
>
fp4_nan_ids
;
fp4_nan_ids
.
insert
(
0b11111111
);
//-NaN
fp4_nan_ids
.
insert
(
0b01111111
);
// +NaN
for
(
ck
::
index_t
exp_id
=
0
;
exp_id
<
256
;
exp_id
++
)
{
if
(
exp_id
==
e8m0_nan_id
)
continue
;
for
(
auto
fp4_nan_id
:
fp4_nan_ids
)
{
auto
idx
=
exp_id
*
256
+
fp4_nan_id
;
ASSERT_TRUE
(
std
::
isnan
(
out
[
idx
]))
<<
"idx: "
<<
idx
<<
" out[idx]: "
<<
out
[
idx
];
}
}
for
(
ck
::
index_t
exp_id
=
0
;
exp_id
<
256
;
exp_id
++
)
for
(
ck
::
index_t
exp_id
=
0
;
exp_id
<
256
;
exp_id
++
)
{
{
if
(
exp_id
==
e8m0_nan_id
)
if
(
exp_id
==
e8m0_nan_id
)
continue
;
continue
;
for
(
ck
::
index_t
fp4_id
=
0
;
fp4_id
<
25
6
;
fp4_id
++
)
for
(
ck
::
index_t
fp4_id
=
0
;
fp4_id
<
1
6
;
fp4_id
++
)
{
{
if
(
fp4_nan_ids
.
find
(
fp4_id
)
!=
fp4_nan_ids
.
end
())
continue
;
uint8_t
fp4_uid
=
static_cast
<
uint8_t
>
(
fp4_id
);
uint8_t
fp4_uid
=
static_cast
<
uint8_t
>
(
fp4_id
);
auto
idx
=
exp_id
*
25
6
+
fp4_uid
;
auto
idx
=
exp_id
*
1
6
+
fp4_uid
;
ASSERT_FLOAT_EQ
(
out
[
idx
],
ASSERT_FLOAT_EQ
(
out
[
idx
],
type_convert
<
float
>
(
e8m0_bexp_t
(
exp_id
))
*
type_convert
<
float
>
(
e8m0_bexp_t
(
exp_id
))
*
type_convert
<
float
>
(
f4_t
(
fp4_uid
&
0b00001111
)))
type_convert
<
float
>
(
f4_t
(
fp4_uid
&
0b00001111
)))
...
@@ -319,19 +296,19 @@ TEST(MXFP4, DeviceScaledConvert)
...
@@ -319,19 +296,19 @@ TEST(MXFP4, DeviceScaledConvert)
auto
i
=
256
*
16
;
auto
i
=
256
*
16
;
// f4x2 -> f32x2
// f4x2 -> f32x2
EXPECT_EQ
(
out
[
i
++
],
-
powf
(
2.0
f
,
-
5
.0
f
)
)
;
EXPECT_EQ
(
out
[
i
++
],
1
.0
f
);
EXPECT_EQ
(
out
[
i
++
],
powf
(
2.0
f
,
-
8
.0
f
)
)
;
EXPECT_EQ
(
out
[
i
++
],
-
4
.0
f
);
// f32x2 -> f4x2
// f32x2 -> f4x2
// RNE
// RNE
EXPECT_EQ
(
out
[
i
++
],
-
4.0
f
);
EXPECT_EQ
(
out
[
i
++
],
0.5
f
);
EXPECT_EQ
(
out
[
i
++
],
2.0
f
);
EXPECT_EQ
(
out
[
i
++
],
-
2.0
f
);
// SR
// SR
EXPECT_EQ
(
out
[
i
++
],
0.5
f
);
EXPECT_EQ
(
out
[
i
++
],
-
2.0
f
);
EXPECT_EQ
(
out
[
i
++
],
-
2.0
f
);
EXPECT_EQ
(
out
[
i
++
],
1.0
f
);
/// Test round to nearest even
/// Test round to nearest even
EXPECT_EQ
(
out
[
i
++
],
10
24.0
f
/
4.0
f
)
<<
"out[i-1]: "
<<
out
[
i
-
1
];
EXPECT_EQ
(
out
[
i
++
],
24.0
f
/
4.0
f
)
<<
"out[i-1]: "
<<
out
[
i
-
1
];
EXPECT_TRUE
(
std
::
isnan
(
out
[
i
++
]))
<<
"out[i-1]: "
<<
out
[
i
-
1
];
EXPECT_TRUE
(
std
::
isnan
(
out
[
i
++
]))
<<
"out[i-1]: "
<<
out
[
i
-
1
];
#if 1
#if 1
EXPECT_TRUE
(
std
::
isnan
(
out
[
i
++
]))
<<
"out[i-1]: "
<<
out
[
i
-
1
];
EXPECT_TRUE
(
std
::
isnan
(
out
[
i
++
]))
<<
"out[i-1]: "
<<
out
[
i
-
1
];
...
@@ -347,7 +324,7 @@ TEST(MXFP4, DeviceScaledConvert)
...
@@ -347,7 +324,7 @@ TEST(MXFP4, DeviceScaledConvert)
EXPECT_EQ
(
out
[
i
++
],
type_convert
<
float
>
(
ck
::
NumericLimits
<
f4_t
>::
Lowest
()))
EXPECT_EQ
(
out
[
i
++
],
type_convert
<
float
>
(
ck
::
NumericLimits
<
f4_t
>::
Lowest
()))
<<
"out[i-1]: "
<<
out
[
i
-
1
];
<<
"out[i-1]: "
<<
out
[
i
-
1
];
#endif
#endif
EXPECT_EQ
(
out
[
i
++
],
type_convert
<
float
>
(
type_convert
<
f4_t
>
(
312.5
f
)))
EXPECT_EQ
(
out
[
i
++
],
type_convert
<
float
>
(
type_convert
<
f4_t
>
(
5.0
f
)))
<<
"out[i-1]: "
<<
out
[
i
-
1
];
<<
"out[i-1]: "
<<
out
[
i
-
1
];
EXPECT_EQ
(
test_size
,
completed
);
EXPECT_EQ
(
test_size
,
completed
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment