Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
60dfab33
Commit
60dfab33
authored
Feb 21, 2026
by
zhanghj2
Browse files
float传bf16使用round_half_ulp_truncate
parent
68971b5c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
15 deletions
+11
-15
csrc/utils.h
csrc/utils.h
+11
-15
No files found.
csrc/utils.h
View file @
60dfab33
...
...
@@ -276,23 +276,21 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso
if
constexpr
(
std
::
is_same_v
<
To_type
,
cutlass
::
bfloat16_t
>
)
{
cutlass
::
NumericArrayConverter
<
To_type
,
From_type
,
numel
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
convert_op
;
*
result_ptr
=
convert_op
(
*
reinterpret_cast
<
const
cutlass
::
Array
<
From_type
,
numel
>
*>
(
tensor
.
data
()));
}
else
if
constexpr
(
std
::
is_same_v
<
To_type
,
cutlass
::
float_e4m3_t
>
)
{
cutlass
::
NumericArrayConverter
<
To_type
,
From_type
,
numel
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
convert_op
;
*
result_ptr
=
convert_op
(
*
reinterpret_cast
<
const
cutlass
::
Array
<
From_type
,
numel
>
*>
(
tensor
.
data
()));
}
else
{
cutlass
::
NumericArrayConverter
<
To_type
,
From_type
,
numel
>
convert_op
;
*
result_ptr
=
convert_op
(
*
reinterpret_cast
<
const
cutlass
::
Array
<
From_type
,
numel
>
*>
(
tensor
.
data
()));
}
return
tensor_To_type
;
}
else
if
constexpr
(
std
::
is_same_v
<
To_type
,
cutlass
::
float_e4m3_t
>
)
{
cutlass
::
NumericArrayConverter
<
To_type
,
From_type
,
numel
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
convert_op
;
*
result_ptr
=
convert_op
(
*
reinterpret_cast
<
const
cutlass
::
Array
<
From_type
,
numel
>
*>
(
tensor
.
data
()));
}
else
{
cutlass
::
NumericArrayConverter
<
To_type
,
From_type
,
numel
>
convert_op
;
*
result_ptr
=
convert_op
(
*
reinterpret_cast
<
const
cutlass
::
Array
<
From_type
,
numel
>
*>
(
tensor
.
data
()));
}
return
tensor_To_type
;
}
#else
{
if
constexpr
(
std
::
is_same_v
<
To_type
,
cutlass
::
bfloat16_t
>
)
{
cutlass
::
NumericArrayConverter
<
To_type
,
From_type
,
numel
,
cutlass
::
FloatRoundStyle
::
round_
toward_zero
>
convert_op
;
cutlass
::
NumericArrayConverter
<
To_type
,
From_type
,
numel
,
cutlass
::
FloatRoundStyle
::
round_
half_ulp_truncate
>
convert_op
;
*
result_ptr
=
convert_op
(
*
reinterpret_cast
<
const
cutlass
::
Array
<
From_type
,
numel
>
*>
(
tensor
.
data
()));
}
else
{
cutlass
::
NumericArrayConverter
<
To_type
,
From_type
,
numel
>
convert_op
;
...
...
@@ -300,8 +298,6 @@ __forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tenso
}
return
tensor_To_type
;
}
#endif
// cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// // HACK: this requires tensor to be "contiguous"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment