OpenDAS / TransformerEngine

Commit 4ef4eae6, authored Jul 02, 2025 by wenjh
Parent: 0e886dab

Resolve merge issues from develop_v2.4

Signed-off-by: wenjh <wenjh@sugon.com>

Showing 5 changed files with 119 additions and 7 deletions (+119 −7)
tests/cpp/test_common.h (+1 −1)
tests/pytorch/references/blockwise_quantizer_reference.py (+1 −0)
tests/pytorch/test_float8_blockwise_scaling_exact.py (+1 −1)
transformer_engine/common/common.h (+1 −1)
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu (+115 −4)
tests/cpp/test_common.h (View file @ 4ef4eae6)

@@ -97,7 +97,7 @@ struct BitsNumber {
 template <typename T>
 struct TypeInfo {
 #if FP4_TYPE_SUPPORTED
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8, fp4e2m1>;
 #else
   using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8>;
 #endif
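The change adds int8 to the FP4-enabled branch of the types tuple so the test harness can enumerate int8 like any other element type. As a rough illustration of why the tuple form matters (a standalone sketch, not code from this repository; test_types and list_element_sizes are made-up names), a std::tuple type list can be walked at compile time with an index sequence:

// Standalone sketch: walking a std::tuple type list at compile time (C++17).
#include <cstdint>
#include <iostream>
#include <tuple>
#include <utility>

using test_types = std::tuple<std::int8_t, std::int16_t, float, double>;

template <typename Tuple, std::size_t... Is>
void list_element_sizes(std::index_sequence<Is...>) {
  // Fold expression: visit every element type in the tuple and print sizeof(T).
  ((std::cout << sizeof(std::tuple_element_t<Is, Tuple>) << ' '), ...);
  std::cout << '\n';
}

int main() {
  list_element_sizes<test_types>(
      std::make_index_sequence<std::tuple_size_v<test_types>>{});  // prints: 1 2 4 8
}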
tests/pytorch/references/blockwise_quantizer_reference.py (View file @ 4ef4eae6)

@@ -278,6 +278,7 @@ class BlockwiseQuantizerReference:
         eps: float = 0.0,
         pow_2_scales: bool = False,
         quant_tile_shape: Tuple[int, int] = (blockwise_fp8_block_len, blockwise_fp8_block_len),
+        munge_scale_shapes: bool = True,
     ) -> QuantizeResult:
         # sanity checks
         assert x.dim() == 2
tests/pytorch/test_float8_blockwise_scaling_exact.py (View file @ 4ef4eae6)

@@ -121,7 +121,7 @@ def test_quantization_1D_block_tiling_with_compact_data_and_scales(
     pow_2_scales: bool,
 ) -> None:
     te_dtype = TE_DType[quant_dtype]
-    tile_size = (1, 128)
+    tile_size = (1, blockwise_fp8_block_len)
     # This test runs a comparison of the ref class versus the class using
     # CUDA kernels to quantize. They should quantize identically for pixels
     # that are not DC values in the scale factor shape.
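The test now derives its tile width from blockwise_fp8_block_len instead of the hard-coded 128. For context, 1D blockwise scaling assigns one scale to each 1x128 strip: take the strip's absolute maximum, derive a scale from the target format's range, and optionally round the scale down to a power of two so that multiplying by the inverse is lossless. A minimal host-side sketch of that arithmetic (assuming FP8 E4M3 with a maximum magnitude of 448; compute_strip_scale is a hypothetical name, not the repository's compute_scale_from_types):

// Hypothetical sketch of per-strip scale computation for 1x128 blockwise FP8.
#include <algorithm>
#include <cmath>
#include <cstdio>

constexpr float kFp8E4M3Max = 448.0f;  // assumed FP8 E4M3 maximum magnitude

float compute_strip_scale(const float* strip, int len, float epsilon, bool pow_2_scale) {
  float amax = 0.0f;
  for (int i = 0; i < len; ++i) amax = std::max(amax, std::fabs(strip[i]));
  amax = std::max(amax, epsilon);  // guard against dividing by zero
  float scale = kFp8E4M3Max / amax;
  if (pow_2_scale) scale = std::exp2f(std::floor(std::log2f(scale)));  // round down to 2^k
  return scale;
}

int main() {
  float strip[128];
  for (int i = 0; i < 128; ++i) strip[i] = 0.01f * i;
  std::printf("scale = %f\n", compute_strip_scale(strip, 128, 1e-12f, true));
}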
transformer_engine/common/common.h (View file @ 4ef4eae6)

@@ -395,7 +395,7 @@ struct BitsNumber {
 template <typename T>
 struct TypeInfo {
 #if FP4_TYPE_SUPPORTED
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp4e2m1>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8, fp4e2m1>;
 #else
   using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8>;
 #endif
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu (View file @ 4ef4eae6)

@@ -575,9 +575,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
     const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
     FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
     const bool pow_2_scaling) {
-  bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
-  bool return_columnwise_transpose =
-      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
+  bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
+  bool return_columnwise_gemm_ready =
+      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+  bool return_columnwise_compact =
+      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
   using SMemVec = Vec<IType, kNVecSMem>;
   using OVec = Vec<OType, kNVecOut>;

@@ -742,7 +744,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   }
   // Step 3: Transpose, cast and store to output_t
-  if (return_columnwise_transpose) {
+  if (return_columnwise_gemm_ready) {
     constexpr int c_stride =
         kThreadsPerBlock / kNumThreadsStore64;  // Stride in columns of shared memory
     constexpr int total_smem_cols = kTileDim64 / kNVecSMem;

@@ -848,6 +850,115 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
       }
     }
   }
+  // Step 4 (return_columnwise_compact): cast in 128x1 style and store to output, skip transpose
+  if (return_columnwise_compact) {
+    // thread tile should be 4x16, 16 means 8 smem reads
+    constexpr int kThreadTileRow = kTileDim64 / kThreadsPerWarp;
+    constexpr int kThreadTileCol = kNVecOut;
+    using RegVec = Vec<IType, kThreadTileCol>;
+    using RegScaleVec = Vec<CType, kThreadTileCol>;
+    constexpr int num_smem_reads = kNVecOut / kNVecSMem;
+    // c_stride will not be used here because we only have one iteration
+    // constexpr int c_stride = kThreadTileCol * kNumWarps / kNVecSMem;
+    constexpr int kNumColBlocks = kTileDim64 / kThreadTileCol;
+    constexpr int num_iterations = 1;  // should be only one iteration
+    static_assert(num_iterations == 1,
+                  "num_iterations should be 1 for columnwise non-transpose case");
+    const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp;
+    const int warp_idx = threadIdx.x / kThreadsPerWarp;
+    if (warp_idx >= kNumColBlocks) {
+      return;  // No work to do
+    }
+    const int r_s = thr_idx_in_warp * kThreadTileRow;  // Row in shared memory
+    int c_s = warp_idx * num_smem_reads;               // Column in shared memory
+    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim64 + r_s;  // Row in global memory
+    const size_t c_g =
+        static_cast<size_t>(blockIdx.x) * kTileDim64 + c_s * kNVecSMem;  // Column in global memory
+    const size_t num_ele =
+        c_g < row_length ? min(static_cast<size_t>(kThreadTileCol), row_length - c_g)
+                         : 0;  // For not aligned case
+#pragma unroll
+    for (int iter = 0; iter < num_iterations; ++iter) {
+      RegVec reg_vec[kThreadTileRow];
+      RegScaleVec thr_scale;
+      // Step 3.1: Load from shared memory to registers
+#pragma unroll
+      for (int i = 0; i < kThreadTileRow; ++i) {
+        int r = r_s + i;
+#pragma unroll
+        for (int j = 0; j < num_smem_reads; ++j) {
+          int c = c_s + j;
+          SMemVec smem_vec = smem[r * kSMemCol64 + c];
+          // copy smem_vec to reg vec with its elements
+#pragma unroll
+          for (int k = 0; k < kNVecSMem; ++k) {
+            reg_vec[i].data.elt[j * kNVecSMem + k] = smem_vec.data.elt[k];
+          }
+        }
+      }
+#pragma unroll
+      for (int reg_idx = 0; reg_idx < kThreadTileCol; ++reg_idx) {
+        // Step 3.2: Compute local amax
+        CType amax = 0;
+#pragma unroll
+        for (int i = 0; i < kThreadTileRow; ++i) {
+          amax = fmaxf(amax, fabsf(reg_vec[i].data.elt[reg_idx]));
+        }
+        // Step 3.3: Reduce amax
+        const bool is_src_lane = thr_idx_in_warp == 0;
+        amax = warp_reduce_max<kThreadsPerWarp>(amax);
+        constexpr int lane_zero = 0;
+#ifdef __HIP_PLATFORM_AMD__
+        amax = __shfl_sync((unsigned long long)(0xFFFFFFFF), amax, lane_zero, kThreadsPerWarp);
+#else
+        amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero);
+#endif
+        // Step 3.4: Compute scale
+        CType scale;
+        scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
+        thr_scale.data.elt[reg_idx] = scale;
+        // Step 3.5: Write scale_inv_t
+        bool write_scale_inv = is_src_lane;
+        if constexpr (!kAligned) {
+          write_scale_inv &= (c_g + reg_idx < row_length);
+        }
+        if (write_scale_inv) {
+          CType scale_inv = 1.0 / scale;
+          size_t row_idx = static_cast<size_t>(blockIdx.y);
+          size_t col_idx =
+              static_cast<size_t>(blockIdx.x) * kTileDim64 + c_s * kNVecSMem + reg_idx;
+          tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
+        }
+      }
+      // Step 3.6: Quantize
+      for (int row_idx = 0; row_idx < kThreadTileRow; ++row_idx) {
+        OType *output_g =
+            &output_t[(r_g + row_idx) * row_length + c_g];  // Output address in global memory
+        OVec output_vec;
+#pragma unroll
+        for (int i = 0; i < kThreadTileCol; ++i) {
+          if constexpr (std::is_same_v<OType, int8_t>) {
+            output_vec.data.elt[i] = static_cast<OType>(lroundf(fmaxf(
+                -127.0f,
+                fminf(127.0f, static_cast<CType>(reg_vec[row_idx].data.elt[i]) *
+                                  thr_scale.data.elt[i]))));
+          } else {
+            output_vec.data.elt[i] = static_cast<OType>(
+                static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
+          }
+        }
+        // Step 3.7: Store output_t
+        if constexpr (kAligned) {
+          output_vec.store_to(output_g);
+        } else {
+          if (r_g + row_idx < num_rows) {
+            output_vec.store_to_elts(output_g, 0, num_ele);
+          }
+        }
+      }
+      // Step 3.8: Update output address, column index of shared memory
+      // this section shouldn't matter since we only have one iteration
+    }
+  }
 }
 #ifdef __HIP_PLATFORM_AMD__
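To summarize what the new COLUMNWISE_COMPACT path does: for every 128x1 column block of the tile it reduces an absolute maximum across a warp, derives a scale, writes the reciprocal scale into tile_scales_inv_t, and stores the quantized values in the original row-major layout instead of transposing them. A serial host-side sketch of the same arithmetic (hypothetical names, a float stand-in for the FP8 output type, and an assumed E4M3 max of 448; the real kernel does this per warp through shared memory, warp_reduce_max, and compute_scale_from_types):

// Hypothetical serial reference for the 128x1 columnwise-compact quantization path.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr int kBlockLen = 128;         // column block height, matching the kernel's tile
constexpr float kFp8E4M3Max = 448.0f;  // assumed FP8 E4M3 maximum magnitude

void quantize_columnwise_compact(const std::vector<float>& x, int rows, int cols,
                                 std::vector<float>& out,         // float stand-in for FP8 data
                                 std::vector<float>& scales_inv)  // one entry per 128x1 block
{
  const int block_rows = (rows + kBlockLen - 1) / kBlockLen;
  scales_inv.assign(static_cast<std::size_t>(block_rows) * cols, 0.0f);
  out.assign(x.size(), 0.0f);
  for (int br = 0; br < block_rows; ++br) {
    const int r_end = std::min(rows, (br + 1) * kBlockLen);
    for (int c = 0; c < cols; ++c) {
      // Step 1: amax over the 128x1 block (the kernel's warp reduction plays this role).
      float amax = 0.0f;
      for (int r = br * kBlockLen; r < r_end; ++r)
        amax = std::max(amax, std::fabs(x[static_cast<std::size_t>(r) * cols + c]));
      // Step 2: scale and its reciprocal (the kernel calls compute_scale_from_types here).
      const float scale = amax > 0.0f ? kFp8E4M3Max / amax : 1.0f;
      scales_inv[static_cast<std::size_t>(br) * cols + c] = 1.0f / scale;
      // Step 3: quantize while keeping the row-major, non-transposed layout.
      for (int r = br * kBlockLen; r < r_end; ++r)
        out[static_cast<std::size_t>(r) * cols + c] =
            x[static_cast<std::size_t>(r) * cols + c] * scale;
    }
  }
}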