Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
5b82e699
"git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "97460585d9ae2c79fb625d0e4ad48f17b753a2da"
Commit
5b82e699
authored
Jun 12, 2025
by
wenjh
Browse files
Merge branch 'develop_v2.4'
parents
9a815d0b
7f946529
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
4 deletions
+15
-4
tests/pytorch/test_int8_blockwise_layers.py
tests/pytorch/test_int8_blockwise_layers.py
+1
-1
transformer_engine/common/recipe/fp8_block_scaling.cu
transformer_engine/common/recipe/fp8_block_scaling.cu
+6
-1
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
...e/common/transpose/quantize_transpose_square_blockwise.cu
+8
-2
No files found.
tests/pytorch/test_int8_blockwise_layers.py
View file @
5b82e699
...
...
@@ -167,7 +167,7 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
dtype
=
dtype
,
y_error
=
0.9
,
ln_out_error
=
0.5
,
dgrad_error
=
1
.5
,
dgrad_error
=
1
,
wgrad_error
=
1
,
bgrad_error
=
0.5
,
recipe1_golden_tensors
=
None
,
...
...
transformer_engine/common/recipe/fp8_block_scaling.cu
View file @
5b82e699
...
...
@@ -116,7 +116,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
if
(
h_in_input
<
h
&&
w_in_input
<
w
&&
idx_in_input
>=
start_offset
&&
idx_in_input
<
end_offset
)
{
float
inp
=
static_cast
<
float
>
(
input_minus_offset
[
idx_in_input
])
*
scale
;
if
constexpr
(
std
::
is_same_v
<
OType
,
int8_t
>
)
{
smem
[
h_in_smem
][
w_in_smem
]
=
static_cast
<
OType
>
(
lroundf
(
fmaxf
(
-
127.0
f
,
fminf
(
127.0
f
,
inp
))));
}
else
{
smem
[
h_in_smem
][
w_in_smem
]
=
static_cast
<
OType
>
(
inp
);
}
skip_store
=
false
;
}
}
...
...
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
View file @
5b82e699
...
...
@@ -431,9 +431,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
for
(
int
j
=
0
;
j
<
THREAD_TILE_DIM_X
;
j
++
)
{
// Step 3: Store cast output
CType
scale_data
=
block_tile_scale
;
OType
scaled_elt
=
OType
scaled_elt
=
0
;
if
constexpr
(
std
::
is_same_v
<
OType
,
int8_t
>
)
{
scaled_elt
=
static_cast
<
OType
>
(
lroundf
(
fmaxf
(
-
127.0
f
,
fminf
(
127.0
f
,
static_cast
<
CType
>
(
thrd_tile_input
[
i
].
data
.
elt
[
j
])
*
scale_data
))));
}
else
{
scaled_elt
=
static_cast
<
OType
>
(
static_cast
<
CType
>
(
thrd_tile_input
[
i
].
data
.
elt
[
j
])
*
scale_data
);
}
tmp_output_c
.
data
.
elt
[
j
]
=
scaled_elt
;
// Step 4: do transpose within thread tile
if
constexpr
(
kReturnTranspose
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment