OpenDAS / tilelang · Commits

Commit fdbf4d6c (unverified)
Authored Aug 04, 2025 by Wenhao Xie; committed by GitHub on Aug 04, 2025
[Enhancement] Optimize BF16 casting performance (#689)

* Use more efficient BF16 type-related conversions
* Update macro
Parent: d2afb513

Changes: 4 changed files with 345 additions and 10 deletions (+345 −10)
src/target/codegen_cuda.cc                      +62  −10
src/target/codegen_cuda.h                        +3   −0
src/tl_templates/cuda/cuda_bf16_fallbacks.cuh  +257   −0
src/tl_templates/cuda/cuda_bf16_wrapper.h       +23   −0
src/target/codegen_cuda.cc (view file @ fdbf4d6c)
...
@@ -192,6 +192,9 @@ std::string CodeGenTileLangCUDA::Finish() {
   decl_stream << "#include <tl_templates/cuda/ldsm.h>\n";
   decl_stream << "#include <tl_templates/cuda/threadblock_swizzle.h>\n";
   decl_stream << "#include <tl_templates/cuda/debug.h>\n";
+  decl_stream << "#ifdef ENABLE_BF16\n";
+  decl_stream << "#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>\n";
+  decl_stream << "#endif\n";
   if (need_global_barrier_) {
     decl_stream << "__device__ unsigned " << vid_global_barrier_state_
...
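The three added decl_stream lines mean every generated CUDA source now carries a preamble roughly like the sketch below (assembled from the emitted strings above; the fallback header is only pulled in when the kernel is compiled with -DENABLE_BF16):

#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif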
@@ -734,18 +737,67 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
   this->PrintIndent();
   this->PrintType(target_ty, stream);
   stream << ' ' << sret << ";\n";
   {
     std::string src = SSAGetID(PrintExpr(op->value), from_ty);
-    for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
-      std::ostringstream val;
-      val << "(";
-      PrintType(target_ty.element_of(), val);
-      val << ")(";
-      PrintVecElemLoad(src, from_ty, i, val);
-      val << ")";
-      PrintVecElemStore(sret, target_ty, i, val.str());
-    }
+    // Handle bfloat16 special cases with supported ops
+    bool used_bf16_op = false;
+    if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) {
+      std::ostringstream func_name;
+      if (from_ty.is_bfloat16())
+        func_name << "bf16";
+      else if (from_ty.is_float())
+        func_name << "float";
+      if (from_ty.lanes() > 1)
+        func_name << from_ty.lanes();
+      func_name << "2";
+      if (target_ty.is_bfloat16())
+        func_name << "bf16";
+      else if (target_ty.is_float())
+        func_name << "float";
+      else if (target_ty == DataType::Int(16))
+        func_name << "int16";
+      if (target_ty.lanes() > 1)
+        func_name << target_ty.lanes();
+      auto fname = func_name.str();
+      if (bf16_supported_ops_.count(fname)) {
+        used_bf16_op = true;
+        stream << "#ifdef ENABLE_BF16\n";
+        PrintIndent();
+        stream << "reinterpret_cast<";
+        if (target_ty.is_bfloat16())
+          stream << "__nv_bfloat16";
+        else
+          PrintType(target_ty.element_of(), stream);
+        if (target_ty.lanes() > 1)
+          stream << target_ty.lanes();
+        stream << " &>(" << sret << ") = fastertransformer::" << fname
+               << "(reinterpret_cast<";
+        if (from_ty.is_bfloat16())
+          stream << "__nv_bfloat16";
+        else
+          PrintType(from_ty.element_of(), stream);
+        if (from_ty.lanes() > 1)
+          stream << from_ty.lanes();
+        stream << " const &>(" << src << "));\n";
+        stream << "#else\n";
+      }
+    }
+    // Fallback: elementwise cast
+    for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
+      std::ostringstream val;
+      val << "(";
+      PrintType(target_ty.element_of(), val);
+      val << ")(";
+      PrintVecElemLoad(src, from_ty, i, val);
+      val << ")";
+      PrintVecElemStore(sret, target_ty, i, val.str());
+    }
+    if (used_bf16_op) {
+      stream << "#endif\n";
+    }
   }
   os << sret;
 }
...
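For a cast that hits the fast path (for example a 2-lane float to bfloat16 cast, where fname is "float22bf162"), the emitted kernel code should look roughly like the sketch below. This is illustrative only: the variable names dst and src are placeholders, and the exact declared type of dst depends on how PrintType renders the vector type.

#ifdef ENABLE_BF16
  reinterpret_cast<__nv_bfloat162 &>(dst) =
      fastertransformer::float22bf162(reinterpret_cast<float2 const &>(src));
#else
  // Elementwise fallback: one scalar conversion per lane, emitted via
  // PrintVecElemLoad / PrintVecElemStore.
#endif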
src/target/codegen_cuda.h (view file @ fdbf4d6c)
...
@@ -125,6 +125,9 @@ private:
                               const VarNode *variable, std::ostream &os);
   int32_t GetWmmaFragmentSize(const std::string &scope, const VarNode *variable,
                               int32_t size);
+  std::unordered_set<std::string> bf16_supported_ops_ = {
+      "bf1622float2", "bf1622int16", "float22bf162",
+      "bf162bf162"};
 };
 } // namespace codegen
...
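The names in bf16_supported_ops_ are composed by the CastNode handler in codegen_cuda.cc above as <from-type><from-lanes>, a literal "2" (read as "to"), then <target-type><target-lanes>. A standalone sketch of that rule follows; the function and parameter names here are invented for illustration and are not part of the commit.

#include <iostream>
#include <string>

// Mirrors the fname composition used by the codegen fast path.
std::string compose_cast_name(const std::string &from, int from_lanes,
                              const std::string &to, int to_lanes) {
  std::string name = from;
  if (from_lanes > 1) name += std::to_string(from_lanes);
  name += "2";  // separator meaning "convert to"
  name += to;
  if (to_lanes > 1) name += std::to_string(to_lanes);
  return name;
}

int main() {
  std::cout << compose_cast_name("float", 2, "bf16", 2) << "\n";  // float22bf162
  std::cout << compose_cast_name("bf16", 2, "float", 2) << "\n";  // bf1622float2
  std::cout << compose_cast_name("bf16", 2, "int16", 1) << "\n";  // bf1622int16
  std::cout << compose_cast_name("bf16", 1, "bf16", 2) << "\n";   // bf162bf162
}

Only these four compositions appear in bf16_supported_ops_, so only those casts take the packed fastertransformer path; everything else falls back to the elementwise loop.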
src/tl_templates/cuda/cuda_bf16_fallbacks.cuh (new file, 0 → 100644, view file @ fdbf4d6c)
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
namespace fastertransformer {

#ifdef ENABLE_BF16

inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  float2 f_val;
  f_val.x = __low2float(val);
  f_val.y = __high2float(val);
  return f_val;
#else
  return __bfloat1622float2(val);
#endif
}

inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  float2 f_val;
  f_val.x = max(min(__low2float(val), 127.f), -128.f);
  f_val.y = max(min(__high2float(val), 127.f), -128.f);
  union {
    int8_t int8[2];
    int16_t int16;
  };
  int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
  int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
  return int16;
#else
  val = __hmin2(val, make_bfloat162(127., 127.));
  val = __hmax2(val, make_bfloat162(-128., -128.));
  union {
    int8_t int8[2];
    int16_t int16;
  };
  int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
  int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
  return int16;
#endif
}

inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  return __floats2bfloat162_rn(val.x, val.y);
#else
  return __float22bfloat162_rn(val);
#endif
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  __nv_bfloat162 val2;
  val2.x = val;
  val2.y = val;
  return val2;
#else
  return __bfloat162bfloat162(val);
#endif
}

inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  float fxl, fxh, fyl, fyh;
  fxl = __low2float(x);
  fxh = __high2float(x);
  fyl = __low2float(y);
  fyh = __high2float(y);
  return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
  return __hadd2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));
#else
  return __hadd(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  float fxl, fxh, fyl, fyh;
  fxl = __low2float(x);
  fxh = __high2float(x);
  fyl = __low2float(y);
  fyh = __high2float(y);
  return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
#else
  return __hsub2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));
#else
  return __hsub(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  float fxl, fxh, fyl, fyh;
  fxl = __low2float(x);
  fxh = __high2float(x);
  fyl = __low2float(y);
  fyh = __high2float(y);
  return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
  return __hmul2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
#else
  return __hmul(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y,
                                           const __nv_bfloat162 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  float fxl, fxh, fyl, fyh, fzl, fzh;
  fxl = __low2float(x);
  fxh = __high2float(x);
  fyl = __low2float(y);
  fyh = __high2float(y);
  fzl = __low2float(z);
  fzh = __high2float(z);
  return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
  return __hfma2(x, y, z);
#endif
}

inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y,
                                         const __nv_bfloat16 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
  return __hfma(x, y, z);
#endif
}

inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  float fxl, fxh;
  fxl = __low2float(x);
  fxh = __high2float(x);
  return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
  return h2exp(x);
#endif
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) {
  return bf16hmul2(x, y);
};
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) {
  return bf16hadd2(x, y);
};

inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) {
  __nv_bfloat162 t;
  t.x = x;
  t.y = y;
  return t;
}
#endif

inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
  return a + b + c;
#endif
}

inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c,
                                         __nv_bfloat16 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) +
                          __bfloat162float(d));
#else
  return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
#endif
}

inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  float fal, fah, fbl, fbh, fcl, fch;
  fal = __low2float(a);
  fah = __high2float(a);
  fbl = __low2float(b);
  fbh = __high2float(b);
  fcl = __low2float(c);
  fch = __high2float(c);
  return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
#else
  return a + b + c;
#endif
}

inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
  return a * b * c;
#endif
}

inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  float fal, fah, fbl, fbh, fcl, fch;
  fal = __low2float(a);
  fah = __high2float(a);
  fbl = __low2float(b);
  fbh = __high2float(b);
  fcl = __low2float(c);
  fch = __high2float(c);
  return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
#else
  return a * b * c;
#endif
}

inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c,
                                           __nv_bfloat162 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
  fal = __low2float(a);
  fah = __high2float(a);
  fbl = __low2float(b);
  fbh = __high2float(b);
  fcl = __low2float(c);
  fch = __high2float(c);
  fdl = __low2float(d);
  fdh = __high2float(d);
  return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
#else
  return a * b * c + d;
#endif
}

#endif // ENABLE_BF16

} // namespace fastertransformer
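A minimal device-side usage sketch (not part of the commit): a kernel that unpacks a __nv_bfloat162 element to float2 with bf1622float2, scales it in fp32, and repacks it with float22bf162. The kernel name, buffer name, and launch geometry are invented for illustration, and it assumes cuda_bf16_fallbacks.cuh is included and the translation unit is compiled with -DENABLE_BF16.

#ifdef ENABLE_BF16
__global__ void scale_bf16x2_kernel(__nv_bfloat162 *data, float scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Unpack to float2, do the arithmetic in fp32, pack back to bf16x2.
    float2 f = fastertransformer::bf1622float2(data[i]);
    f.x *= scale;
    f.y *= scale;
    data[i] = fastertransformer::float22bf162(f);
  }
}
#endif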
src/tl_templates/cuda/cuda_bf16_wrapper.h (new file, 0 → 100644, view file @ fdbf4d6c)
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif