Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
9a815d0b
Commit
9a815d0b
authored
Jun 12, 2025
by
wenjh
Browse files
Merge branch 'develop_v2.4'
parents
3d57ff8c
e2860c76
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
111 additions
and
17 deletions
+111
-17
tests/pytorch/references/blockwise_quantizer_reference.py
tests/pytorch/references/blockwise_quantizer_reference.py
+2
-0
tests/pytorch/test_float8_blockwise_scaling_exact.py
tests/pytorch/test_float8_blockwise_scaling_exact.py
+8
-5
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
...e/common/transpose/quantize_transpose_vector_blockwise.cu
+101
-12
No files found.
tests/pytorch/references/blockwise_quantizer_reference.py
View file @
9a815d0b
...
...
@@ -170,6 +170,8 @@ class BlockwiseQuantizerReference:
)
qx
=
x_tiled
*
scale
.
reshape
(
M
,
K
//
tile_len
,
1
)
qx
=
torch
.
clamp
(
qx
,
min
=-
dtype_max
,
max
=
dtype_max
)
if
quant_dtype
==
torch
.
int8
:
qx
=
torch
.
round
(
qx
)
qx
=
qx
.
to
(
dtype
=
quant_dtype
)
qx
=
qx
.
reshape
(
M
,
K
)
return
qx
,
scale_inv
...
...
tests/pytorch/test_float8_blockwise_scaling_exact.py
View file @
9a815d0b
...
...
@@ -153,7 +153,7 @@ def check_quantization_block_tiling_versus_reference(
)
# Check
torch
.
testing
.
assert_close
(
qx
.
float
(),
qx_ref
.
float
(),
atol
=
0.0
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
qx
.
float
(),
qx_ref
.
float
(),
atol
=
0.0
if
quant_dtype
!=
torch
.
int8
else
1.0
,
rtol
=
0.0
)
# Zero out the don't-care values.
# Scale format has padding.
scale_mask
=
torch
.
ones
(
...
...
@@ -163,7 +163,7 @@ def check_quantization_block_tiling_versus_reference(
QuantizeResult
(
qx
,
scale_mask
,
None
,
None
),
tile_size
).
scale
sx
=
sx
*
scale_mask
torch
.
testing
.
assert_close
(
sx
,
sx_ref
,
atol
=
0.0
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
sx
,
sx_ref
,
atol
=
0.0
if
x_dtype
!=
torch
.
float32
else
1e-5
,
rtol
=
0.0
if
x_dtype
!=
torch
.
float32
else
5e-5
)
if
return_transpose
:
assert
qx_t
is
not
None
...
...
@@ -179,8 +179,8 @@ def check_quantization_block_tiling_versus_reference(
QuantizeResult
(
qx_t
,
scale_mask
,
None
,
None
),
tile_size
).
scale
sx_t
=
sx_t
*
scale_mask
torch
.
testing
.
assert_close
(
qx_t
.
float
(),
qx_t_ref
.
float
(),
atol
=
0.0
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
sx_t
,
sx_t_ref
,
atol
=
0.0
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
qx_t
.
float
(),
qx_t_ref
.
float
(),
atol
=
0.0
if
quant_dtype
!=
torch
.
int8
else
1.0
,
rtol
=
0.0
if
x_dtype
!=
torch
.
float32
else
2.5e-1
)
torch
.
testing
.
assert_close
(
sx_t
,
sx_t_ref
,
atol
=
0.0
if
x_dtype
!=
torch
.
float32
else
1e-5
,
rtol
=
0.0
if
x_dtype
!=
torch
.
float32
else
5e-5
)
else
:
# should be None
assert
qx_t
is
None
and
qx_t_ref
is
None
...
...
@@ -344,7 +344,10 @@ def test_quantization_block_tiling_extrema_versus_reference(
torch
.
testing
.
assert_close
(
sx
.
flatten
()[
0
],
sx_ref
.
flatten
()[
0
],
atol
=
0.0
,
rtol
=
0.0
)
if
extrema_high
:
expected_value
=
torch
.
finfo
(
quant_dtype
).
max
/
torch
.
finfo
(
x_dtype
).
max
if
quant_dtype
==
torch
.
int8
:
expected_value
=
torch
.
iinfo
(
quant_dtype
).
max
/
torch
.
finfo
(
x_dtype
).
max
else
:
expected_value
=
torch
.
finfo
(
quant_dtype
).
max
/
torch
.
finfo
(
x_dtype
).
max
if
pow_2_scales
:
expected_value
=
math
.
floor
(
math
.
log2
(
expected_value
))
expected_value
=
math
.
pow
(
2.0
,
expected_value
)
...
...
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
View file @
9a815d0b
...
...
@@ -27,6 +27,90 @@
#include "common/utils.cuh"
namespace
transformer_engine
{
#ifdef __HIP_PLATFORM_AMD__
/// Reports whether an `int` keeps its least-significant byte at the lowest address.
///
/// @return true on a little-endian byte layout, false otherwise.
__device__ bool is_little_endian() {
  const int probe = 1;
  // The first byte in memory holds the 1 only when the low-order byte comes first.
  return *reinterpret_cast<const char *>(&probe) == 1;
}
/// 24-bit storage form of an IEEE-754 binary32 value.
///
/// Only bits [8, 32) of the float pattern are kept (sign, 8 exponent bits and the
/// upper 15 mantissa bits), so one value occupies exactly three bytes. Converting
/// back to `float` restores those bits and zero-fills the dropped low 8 bits.
struct BitFloat {
 private:
  // Bits [8, 32) of the rounded float pattern, least-significant byte first.
  // `unsigned char` keeps the unpacking shifts free of sign extension.
  unsigned char data[3];

 public:
  /// Packs `val` into 24 bits.
  ///
  /// @param val        Value to store.
  /// @param pow2scale  Selects how the 8 dropped mantissa bits are folded in:
  ///                   - false: round to nearest, ties to even (a value just below
  ///                     the next binade, or FLT_MAX, may round up across it);
  ///                   - true: if any dropped bit is set, the lowest kept bit is
  ///                     forced to 1 instead of rounding, so the exponent field is
  ///                     never altered by the conversion.
  ///                     NOTE(review): presumably this keeps a power-of-two scale
  ///                     derived from these values from being pushed up one
  ///                     binade - confirm with the kernel authors.
  __device__ BitFloat(const float val, bool pow2scale) {
    static_assert(sizeof(float) == sizeof(uint32_t), "BitFloat expects a 32-bit float");
    constexpr uint32_t kExponentMask = 0x7f800000u;
    constexpr uint32_t kLowestKeptBit = 0x100u;

    // Copy the bit pattern out; dereferencing a casted pointer would break the
    // aliasing rules. The builtin needs no extra header.
    uint32_t raw_val;
    __builtin_memcpy(&raw_val, &val, sizeof(raw_val));

    const bool is_finite = (raw_val & kExponentMask) != kExponentMask;
    if (is_finite) {
      if (pow2scale && (raw_val & 0x000000FFu)) {
        // Sticky bit: record that precision was lost without carrying upward.
        raw_val |= kLowestKeptBit;
      } else {
        // Round to nearest even on bit 8. With a zero low byte this adds at most
        // 0x80 and therefore leaves the kept bits untouched.
        raw_val += 0x7fu + ((raw_val >> 8) & 1u);
      }
    } else if (raw_val & 0xffffu) {
      // NaN whose payload would partly be truncated away: set a kept mantissa bit
      // so it cannot collapse into infinity. The 16-bit mask is wider than the 8
      // bits actually dropped; it is kept unchanged so results stay identical,
      // and anything it additionally matches is already a NaN.
      raw_val |= kLowestKeptBit;
    }

    // Endian-neutral packing of bits [8, 32).
    data[0] = static_cast<unsigned char>(raw_val >> 8);
    data[1] = static_cast<unsigned char>(raw_val >> 16);
    data[2] = static_cast<unsigned char>(raw_val >> 24);
  }

  /// Expands the stored 24 bits back to a `float` (low 8 mantissa bits are zero).
  __device__ operator float() const {
    const uint32_t raw_val = (static_cast<uint32_t>(data[0]) << 8) |
                             (static_cast<uint32_t>(data[1]) << 16) |
                             (static_cast<uint32_t>(data[2]) << 24);
    float val;
    __builtin_memcpy(&val, &raw_val, sizeof(val));
    return val;
  }
};
/// Two BitFloat values stored back to back (`u` first, then `v`).
/// Serves as the 6-byte type selected by BytesToType<6>.
struct BitFloat2 {
  BitFloat u;
  BitFloat v;
};
/// BytesToType specialization for a 6-byte width: resolves to BitFloat2.
/// NOTE(review): presumably needed for a two-element vector of BitFloat - confirm
/// against the Vec<> instantiations in this file.
template <>
struct BytesToType<6> {
  using Type = BitFloat2;
  // The mapping is only valid if the type really occupies six bytes (no padding).
  static_assert(sizeof(Type) == 6, "BytesToType<6>::Type must be exactly 6 bytes");
};
#endif
namespace
{
using
transformer_engine
::
detail
::
FP8BlockwiseColumnwiseOption
;
...
...
@@ -169,7 +253,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
extern
__shared__
char
smem_base
[];
#ifdef __HIP_PLATFORM_AMD__
using
HipSMemVec
=
Vec
<
std
::
conditional_t
<
std
::
is_same_v
<
IType
,
float
>
,
__hip_bf
loat
16
,
IType
>
,
kNVecSMem
>
;
using
HipSMemVec
=
Vec
<
std
::
conditional_t
<
std
::
is_same_v
<
IType
,
float
>
,
BitF
loat
,
IType
>
,
kNVecSMem
>
;
HipSMemVec
*
smem
=
reinterpret_cast
<
HipSMemVec
*>
(
&
smem_base
[
0
]);
#else
SMemVec
*
smem
=
reinterpret_cast
<
SMemVec
*>
(
&
smem_base
[
0
]);
...
...
@@ -213,14 +297,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
#pragma unroll
for
(
int
j
=
0
;
j
<
kNVecSMem
;
++
j
)
{
uint32_t
raw_val
=
*
reinterpret_cast
<
const
uint32_t
*>
(
&
input_vec
.
smem_type
.
data
.
elt
[
i
].
data
.
elt
[
j
]);
// [Workaround] Under certain critical conditions, scale will be 2 * ref_scale because of float -> bfloat16.
// We use carry over here to avoid this issue.
if
(
pow_2_scaling
&&
(
raw_val
&
0x0000FFFF
))
{
raw_val
|=
0x00010000
;
}
smem
[
r
*
kSMemCol
+
c
].
data
.
elt
[
j
]
=
*
reinterpret_cast
<
const
float
*>
(
&
raw_val
);
smem
[
r
*
kSMemCol
+
c
].
data
.
elt
[
j
]
=
BitFloat
(
input_vec
.
smem_type
.
data
.
elt
[
i
].
data
.
elt
[
j
],
pow_2_scaling
);
}
}
else
...
...
@@ -335,8 +412,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
for
(
int
i
=
0
;
i
<
kNVecOut
/
kNVecSMem
;
++
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
kNVecSMem
;
++
j
)
{
output_vec
.
data
.
elt
[
i
*
kNVecSMem
+
j
]
=
if
constexpr
(
std
::
is_same_v
<
OType
,
int8_t
>
)
{
output_vec
.
data
.
elt
[
i
*
kNVecSMem
+
j
]
=
static_cast
<
OType
>
(
lroundf
(
fmaxf
(
-
127.0
f
,
fminf
(
127.0
f
,
static_cast
<
CType
>
(
smem_vec
[
i
].
data
.
elt
[
j
])
*
scale
))));
}
else
{
output_vec
.
data
.
elt
[
i
*
kNVecSMem
+
j
]
=
static_cast
<
OType
>
(
static_cast
<
CType
>
(
smem_vec
[
i
].
data
.
elt
[
j
])
*
scale
);
}
}
}
// Step 2.7: Store output_c
...
...
@@ -445,8 +528,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
OVec
output_vec
;
#pragma unroll
for
(
int
i
=
0
;
i
<
kNVecOut
;
++
i
)
{
output_vec
.
data
.
elt
[
i
]
=
if
constexpr
(
std
::
is_same_v
<
OType
,
int8_t
>
)
{
output_vec
.
data
.
elt
[
i
]
=
static_cast
<
OType
>
(
lroundf
(
fmaxf
(
-
127.0
f
,
fminf
(
127.0
f
,
static_cast
<
CType
>
(
smem_vec
[
i
].
data
.
elt
[
smem_idx
])
*
scale
))));
}
else
{
output_vec
.
data
.
elt
[
i
]
=
static_cast
<
OType
>
(
static_cast
<
CType
>
(
smem_vec
[
i
].
data
.
elt
[
smem_idx
])
*
scale
);
}
}
// Step 3.7: Store output_t
if
constexpr
(
kAligned
)
{
...
...
@@ -550,7 +639,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
full_tile
,
kAligned
,
#ifdef __HIP_PLATFORM_AMD__
using
HipSMemType
=
std
::
conditional_t
<
std
::
is_same_v
<
InputType
,
float
>
,
hip_bf
loat
16
,
InputType
>
;
using
HipSMemType
=
std
::
conditional_t
<
std
::
is_same_v
<
InputType
,
float
>
,
BitF
loat
,
InputType
>
;
size_t
smem_bytes
=
kSMemSize
*
sizeof
(
HipSMemType
);
#else
size_t
smem_bytes
=
kSMemSize
*
sizeof
(
InputType
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment