OpenDAS / TransformerEngine / Commits / 2b05e121

Commit 2b05e121, authored Jun 17, 2025 by yuguo

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine

Parents: 0fd441c2, a69692ac

Changes: 245 files total; showing 20 changed files with 414 additions and 102 deletions (+414 −102).
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu (+152 −23)
transformer_engine/common/transpose/transpose.cu (+3 −4)
transformer_engine/common/transpose/transpose_fusion.cu (+8 −8)
transformer_engine/common/util/cast.cu (+44 −2)
transformer_engine/common/util/cast_gated_kernels.cuh (+18 −16)
transformer_engine/common/util/cast_kernels.cuh (+34 −21)
transformer_engine/common/util/cuda_driver.cpp (+30 −4)
transformer_engine/common/util/cuda_driver.h (+1 −1)
transformer_engine/common/util/dequantize_kernels.cuh (+2 −2)
transformer_engine/common/util/multi_stream.cpp (+61 −0, new file)
transformer_engine/common/util/multi_stream.h (+20 −0, new file)
transformer_engine/common/util/padding.cu (+4 −4)
transformer_engine/common/util/pybind_helper.h (+4 −0)
transformer_engine/common/utils.cuh (+4 −0)
transformer_engine/debug/features/api.py (+3 −3)
transformer_engine/debug/features/fake_quant.py (+1 −1)
transformer_engine/debug/features/log_fp8_tensor_stats.py (+0 −1)
transformer_engine/debug/features/per_tensor_scaling.py (+3 −3)
transformer_engine/debug/features/utils/stats_computation.py (+5 −2)
transformer_engine/debug/pytorch/debug_quantization.py (+17 −7)
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
@@ -190,14 +190,14 @@ Step 2: Cast and store to output_c
 |                              ...                              |
 +-------------------------------+-------------------------------+-------------------------------+-------------------------------+
-Step 3: Transpose, cast and store to output_t
+Step 3 (if columnwise transpose is True, GEMM_READY): Transpose, cast and store to output_t
 * shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
 * 8 warps
 * Loop 2 times
 * What each thread does in each loop:
   * 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
   * Every 8 consecutive threads do reduction and calculate the amax of each column
   * 16 elements are quantized and written to output_c at a time, for a total of 2 times
   * 16 elements are quantized and written to output_t at a time, for a total of 2 times
 +------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
 | T0 | T8 | T16 | T24 |     |     |     | T0 | T8 | T16 | T24 |     |     |     |
 | T1 | T9 | T17 | T25 |     |     |     | T1 | T9 | T17 | T25 |     |     |     |
@@ -209,6 +209,29 @@ Step 3: Transpose, cast and store to output_t
 | T7 | T15 | T23 | T31 |     |     |     | T7 | T15 | T23 | T31 |     |     |     |
 +-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
+Step 3 (if columnwise transpose is False, COMPACT format): Skip transpose, cast and store to output_t
+* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
+* 8 warps
+* Loop 1 time
+* What each thread does in each loop:
+  * 16 elements (in a row) are read from the shared memory, for a total of 4 rows;
+    it needs 8 reads in smem to get 16 elements in a row, so the thread tile shape is 16x4
+  * Every 32 consecutive threads in a warp do reduction and calculate the amax of each column,
+    so each thread will do warp shuffle 16 times to get the amax of each column
+  * 16 elements are quantized and written to output_t at a time, for a total of 4 times
++------16 elements------+------16 elements------+-----80 elements-------+------16 elements------+
+| T0                    |                       |                       |                       |
+| T1                    |                       |                       |                       |
+| T2                    |                       |                       |                       |
+| T3                    |                       |                       |                       |
+| T4                    |                       |                       |                       |
+| T5                    |                       |                       |                       |
+| T6                    |                       |                       |                       |
+| T7                    |                       |                       |                       |
+| ...                   |                       |                       |                       |
+| T31                   |                       |                       |                       |
++-----------------------+-----------------------+-----------------------+-----------------------+
 */
 // clang-format on
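The per-column amax reduction described in the comment above is an ordinary warp shuffle reduction. For reference, a minimal sketch, assuming 32-lane warps and float values (the file's actual warp_reduce_max helper may differ in signature and detail):

    // Hypothetical warp-wide max reduction; lane 0 holds the result afterwards.
    __device__ float warp_max(float v) {
      // Each step halves the number of distinct values; after 5 steps all
      // 32 lanes have contributed to lane 0's running maximum.
      for (int offset = 16; offset > 0; offset /= 2) {
        v = fmaxf(v, __shfl_down_sync(0xFFFFFFFF, v, offset));
      }
      return v;  // broadcast with __shfl_sync(0xFFFFFFFF, v, 0) if all lanes need it
    }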
@@ -231,6 +254,7 @@ constexpr int kNumThreadsLoad = kTileDim / kNVecIn;
 constexpr int kNumThreadsStore = kTileDim / kNVecOut;
 static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
 static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
+constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;

 template <bool kAligned, typename CType, typename IType, typename OType>
 __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
@@ -240,9 +264,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
     FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
     const bool pow_2_scaling) {
-  bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
-  bool return_columnwise_transpose =
-      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
+  bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
+  bool return_columnwise_gemm_ready =
+      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+  bool return_columnwise_compact =
+      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
   using SMemVec = Vec<IType, kNVecSMem>;
   using OVec = Vec<OType, kNVecOut>;
@@ -439,8 +465,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     }
   }

-  // Step 3: Transpose, cast and store to output_t
-  if (return_columnwise_transpose) {
+  // Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
+  if (return_columnwise_gemm_ready) {
     constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore;  // Stride in columns of shared memory
     constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
@@ -554,6 +580,103 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       }
     }
   }
+  // Step 4 (return_columnwise_compact): cast in 128x1 style and store to output, skip transpose
+  if (return_columnwise_compact) {
+    // thread tile should be 4x16, 16 means 8 smem reads
+    constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp;
+    constexpr int kThreadTileCol = kNVecOut;
+    using RegVec = Vec<IType, kThreadTileCol>;
+    using RegScaleVec = Vec<CType, kThreadTileCol>;
+    constexpr int num_smem_reads = kNVecOut / kNVecSMem;
+    // c_stride will not be used here because we only have one iteration
+    // constexpr int c_stride = kThreadTileCol * kNumWarps / kNVecSMem;
+    constexpr int num_iterations = kTileDim / (kNumWarps * kThreadTileCol);
+    // should be only one iteration
+    static_assert(num_iterations == 1, "num_iterations should be 1 for columnwise non-transpose case");
+    const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp;
+    const int warp_idx = threadIdx.x / kThreadsPerWarp;
+    const int r_s = thr_idx_in_warp * kThreadTileRow;  // Row in shared memory
+    int c_s = warp_idx * num_smem_reads;               // Column in shared memory
+    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Row in global memory
+    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
+    const size_t num_ele =
+        c_g < row_length ? min(static_cast<size_t>(kThreadTileCol), row_length - c_g) : 0;  // For not aligned case
+#pragma unroll
+    for (int iter = 0; iter < num_iterations; ++iter) {
+      RegVec reg_vec[kThreadTileRow];
+      RegScaleVec thr_scale;
+      // Step 3.1: Load from shared memory to registers
+#pragma unroll
+      for (int i = 0; i < kThreadTileRow; ++i) {
+        int r = r_s + i;
+#pragma unroll
+        for (int j = 0; j < num_smem_reads; ++j) {
+          int c = c_s + j;
+          SMemVec smem_vec = smem[r * kSMemCol + c];
+          // copy smem_vec to reg vec with its elements
+#pragma unroll
+          for (int k = 0; k < kNVecSMem; ++k) {
+            reg_vec[i].data.elt[j * kNVecSMem + k] = smem_vec.data.elt[k];
+          }
+        }
+      }
+#pragma unroll
+      for (int reg_idx = 0; reg_idx < kThreadTileCol; ++reg_idx) {
+        // Step 3.2: Compute local amax
+        CType amax = 0;
+#pragma unroll
+        for (int i = 0; i < kThreadTileRow; ++i) {
+          amax = fmaxf(amax, fabsf(reg_vec[i].data.elt[reg_idx]));
+        }
+        // Step 3.3: Reduce amax
+        const bool is_src_lane = thr_idx_in_warp == 0;
+        amax = warp_reduce_max<kThreadsPerWarp>(amax);
+        constexpr int lane_zero = 0;
+        amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero);
+        // Step 3.4: Compute scale
+        CType scale;
+        scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
+        thr_scale.data.elt[reg_idx] = scale;
+        // Step 3.5: Write scale_inv_t
+        bool write_scale_inv = is_src_lane;
+        if constexpr (!kAligned) {
+          write_scale_inv &= (c_g + reg_idx < row_length);
+        }
+        if (write_scale_inv) {
+          CType scale_inv = 1.0 / scale;
+          size_t row_idx = static_cast<size_t>(blockIdx.y);
+          size_t col_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + reg_idx;
+          tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
+        }
+      }
+      // Step 3.6: Quantize
+      for (int row_idx = 0; row_idx < kThreadTileRow; ++row_idx) {
+        OType *output_g = &output_t[(r_g + row_idx) * row_length + c_g];  // Output address in global memory
+        OVec output_vec;
+#pragma unroll
+        for (int i = 0; i < kThreadTileCol; ++i) {
+          output_vec.data.elt[i] = static_cast<OType>(
+              static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
+        }
+        // Step 3.7: Store output_t
+        if constexpr (kAligned) {
+          output_vec.store_to(output_g);
+        } else {
+          if (r_g + row_idx < num_rows) {
+            output_vec.store_to_elts(output_g, 0, num_ele);
+          }
+        }
+      }
+      // Step 3.8: Update output address, column index of shared memory
+      // this section shouldn't matter since we only have one iteration
+    }
+  }
 }
 }  // namespace
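To make the compact path's tile arithmetic concrete, a worked example under the constants the file's comments suggest (kTileDim = 128, 32-thread warps, 8 warps, kNVecOut = 16 — illustrative values, not quoted from this file's actual definitions):

    // Compact-path thread tile: 4 rows x 16 columns per thread, one iteration total.
    constexpr int kTileDim = 128, kThreadsPerWarp = 32, kNumWarps = 8, kNVecOut = 16;
    constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp;               // 128 / 32  = 4
    constexpr int kThreadTileCol = kNVecOut;                                 //             16
    constexpr int num_iterations = kTileDim / (kNumWarps * kThreadTileCol);  // 128 / 128 = 1
    static_assert(kThreadTileRow == 4 && num_iterations == 1,
                  "matches the kernel's static_assert above");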
@@ -569,11 +692,6 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
                                          const bool pow2_scale, cudaStream_t stream) {
   NVTE_API_CALL(quantize_transpose_vector_blockwise);
-  // assert that rowwise_option and columnwise_option are not both NONE
-  NVTE_CHECK(
-      rowwise_option != FP8BlockwiseRowwiseOption::NONE ||
-          columnwise_option != FP8BlockwiseColumnwiseOption::NONE,
-      "rowwise_option and columnwise_option cannot both be NONE");
   const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
   size_t num_elements = row_length;
   size_t num_rows = 1;
@@ -594,32 +712,43 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
   size_t scale_t_stride_y = 0;

   if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
-    NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE,
+    NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY ||
+                   rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_COMPACT,
               "Unexpected rowwise enum value");
     NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
     NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2.");
     size_t scale_k = scale_inv.shape[1];
-    scale_stride_x = scale_k;
-    scale_stride_y = 1;
+    bool rowwise_compact = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_COMPACT;
+    scale_stride_x = rowwise_compact ? 1 : scale_k;
+    scale_stride_y = rowwise_compact ? scale_k : 1;
   }

   if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) {
-    NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE,
-               "Unexpected columnwise enum value");
     NVTE_CHECK(output_t.shape.size() == input.shape.size(),
                "output_t must have same number of dimensions as input.");
     if (output_t.shape.size() > 0) {
+      if (columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY) {
        NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
        for (size_t i = 1; i < output_t.shape.size(); ++i) {
          NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
        }
+      } else {
+        NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT,
+                   "Unexpected columnwise option enum value");
+        NVTE_CHECK(output_t.shape[0] == input.shape[0], "Wrong dimension 0 of output_t.");
+        NVTE_CHECK(input.shape == output_t.shape,
+                   "Input and output_t must have the same shape for columnwise non-transpose case.");
+      }
     }
     NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype.");
     NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2.");
-    scale_t_stride_x = scale_inv_t.shape[1];
-    scale_t_stride_y = 1;
+    bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
+    size_t scale_t_k = scale_inv_t.shape[1];
+    scale_t_stride_x = columnwise_compact ? 1 : scale_t_k;
+    scale_t_stride_y = columnwise_compact ? scale_t_k : 1;
   }

   const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
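The stride selection above is the whole difference between the two scale layouts: the kernel always addresses the scale buffer as row_idx * scale_t_stride_y + col_idx * scale_t_stride_x, so swapping the two strides flips between the GEMM-ready (transposed) walk and the compact row-major walk. A minimal sketch of that addressing (function name is illustrative):

    // GEMM_READY: stride_x = scale_k, stride_y = 1       (transposed walk)
    // COMPACT:    stride_x = 1,       stride_y = scale_k (plain row-major walk)
    size_t scale_index(size_t row, size_t col, bool compact, size_t scale_k) {
      const size_t stride_x = compact ? 1 : scale_k;
      const size_t stride_y = compact ? scale_k : 1;
      return row * stride_y + col * stride_x;
    }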
transformer_engine/common/transpose/transpose.cu

@@ -288,14 +288,13 @@ void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stre
   NVTE_API_CALL(nvte_transpose);
   using namespace transformer_engine;
   auto noop = Tensor();
-  transpose(*reinterpret_cast<const Tensor *>(input), noop, reinterpret_cast<Tensor *>(output),
-            stream);
+  transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream);
 }

 void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
                               cudaStream_t stream) {
   NVTE_API_CALL(nvte_transpose_with_noop);
   using namespace transformer_engine;
-  transpose(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(noop),
-            reinterpret_cast<Tensor *>(output), stream);
+  transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop),
+            convertNVTETensor(output), stream);
 }
transformer_engine/common/transpose/transpose_fusion.cu

@@ -386,17 +386,18 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
     workspace->data.dtype = DType::kFloat32;
   } else {
     // Check that workspace matches expected size
-    const size_t workspace_size =
-        std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1,
-                        std::multiplies<size_t>()) *
-        typeToSize(workspace->data.dtype);
-    const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32);
+    const size_t workspace_size = get_buffer_size_bytes(
+        std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1,
+                        std::multiplies<size_t>()),
+        workspace->data.dtype);
+    const size_t required_size =
+        get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
     NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
                num_rows_partial_dbias, ",", row_length, "), found ())");
     NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
                num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
-               "; found dims=", workspace->data.shape, ", dtype=", typeToSize(workspace->data.dtype),
-               ")");
+               "; found dims=", workspace->data.shape, ", dtype=", typeToNumBits(workspace->data.dtype),
+               " bits)");
   }
 }
@@ -513,7 +514,6 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_outp
                               NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_fp8_transpose_dbias);
   using namespace transformer_engine;
-  fp8_transpose_dbias(*reinterpret_cast<const Tensor *>(input),
-                      reinterpret_cast<Tensor *>(transposed_output),
-                      reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
+  fp8_transpose_dbias(*convertNVTETensorCheck(input), convertNVTETensor(transposed_output),
+                      convertNVTETensor(dbias), convertNVTETensor(workspace), stream);
 }
transformer_engine/common/util/cast.cu

@@ -10,13 +10,16 @@
 #endif

 #include <cuda_runtime.h>
 #include <transformer_engine/cast.h>
+#include <transformer_engine/multi_stream.h>

 #include <cfloat>
 #include <limits>
+#include <mutex>
 #include <string>

 #include "../common.h"
 #include "../transpose/cast_transpose.h"
+#include "../util/multi_stream.h"
 #include "../util/vectorized_pointwise.h"
 #include "../utils.cuh"
 #include "cast_kernels.cuh"
@@ -156,6 +159,45 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
 void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_dequantize);
   using namespace transformer_engine;
-  detail::dequantize_helper(*reinterpret_cast<const Tensor *>(input),
-                            reinterpret_cast<Tensor *>(output), stream);
+  detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
 }

+void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
+                                const NVTEQuantizationConfig quant_configs,
+                                const size_t num_tensors, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_tensor_quantize);
+  using namespace transformer_engine;
+  constexpr bool IS_DBIAS = false;
+  constexpr bool IS_DACT = false;
+  constexpr bool IS_ACT = false;
+  constexpr NVTETensor dbias = nullptr;
+  constexpr NVTETensor workspace = nullptr;
+  constexpr const NVTETensor grad = nullptr;
+
+  const size_t num_streams = nvte_get_num_compute_streams();
+  int num_stream_used = std::min(num_streams, num_tensors);
+  // wait for current stream to finish
+  NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
+  for (int s = 0; s < num_stream_used; s++) {
+    NVTE_CHECK_CUDA(
+        cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
+  }
+
+  for (int i = 0; i < num_tensors; i++) {
+    detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
+        inputs[i], grad, outputs[i], dbias, workspace, nullptr,
+        detail::get_compute_stream(i % num_streams));
+  }
+
+  // record events on compute streams
+  for (int s = 0; s < num_stream_used; s++) {
+    NVTE_CHECK_CUDA(
+        cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
+  }
+  // wait for all compute streams to finish
+  for (int s = 0; s < num_stream_used; s++) {
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
+  }
+}
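nvte_multi_tensor_quantize above is a standard fork-join stream pattern: record an event on the caller's stream, make every compute stream wait on it, fan the per-tensor work out round-robin, then make the caller's stream wait on a completion event from each compute stream. Stripped of TE's cached stream/event pool, the skeleton looks roughly like this (a sketch only; my_work, side_stream, and the event variables are placeholders assumed to be created elsewhere):

    // Fork: every side stream waits for main_stream's prior work.
    cudaEventRecord(start_event, main_stream);
    for (int s = 0; s < n; ++s) cudaStreamWaitEvent(side_stream[s], start_event);
    // Work: round-robin dispatch across the side streams.
    for (int i = 0; i < num_jobs; ++i) my_work(i, side_stream[i % n]);
    // Join: main_stream waits for every side stream to finish.
    for (int s = 0; s < n; ++s) {
      cudaEventRecord(done_event[s], side_stream[s]);
      cudaStreamWaitEvent(main_stream, done_event[s]);
    }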
transformer_engine/common/util/cast_gated_kernels.cuh

@@ -763,19 +763,20 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
   if constexpr (IS_DGATED) {
     create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X,
-                         cols, 0, sizeof(IType));
+                         cols, 0, typeToNumBits(gated_input.dtype()));
   }

   const uint32_t tensor_stride_elems = output_cols;
   create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y,
-                       SHMEM_DIM_X, cols * 2, 0, sizeof(IType));
+                       SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype()));
   create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y,
-                       SHMEM_DIM_X, cols * 2, cols, sizeof(IType));
+                       SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype()));
   create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y,
-                       SHMEM_DIM_X, tensor_stride_elems, 0, sizeof(OType));
+                       SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype()));
   create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y,
-                       SHMEM_DIM_X, tensor_stride_elems, cols, sizeof(OType));
+                       SHMEM_DIM_X, tensor_stride_elems, cols, typeToNumBits(output->dtype()));

   const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
   const size_t buff_size_aligned_in =
@@ -862,31 +863,33 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
   if constexpr (IS_DGATED) {
     create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y,
-                         SHMEM_DIM_X, cols, 0, sizeof(IType));
+                         SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype()));
   }

   const uint32_t tensor_stride_elems = output_cols;
   create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols,
-                       SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, sizeof(IType));
+                       SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype()));
   create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols,
-                       SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, sizeof(IType));
+                       SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype()));

   if (USE_ROWWISE_SCALING) {
     create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, SHMEM_DIM_Y,
-                         SHMEM_DIM_X, tensor_stride_elems, 0, sizeof(OType));
+                         SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype()));
     create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, SHMEM_DIM_Y,
-                         SHMEM_DIM_X, tensor_stride_elems, cols, sizeof(OType));
+                         SHMEM_DIM_X, tensor_stride_elems, cols, typeToNumBits(output->dtype()));
   }
   if (USE_COLWISE_SCALING) {
     create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols,
                          SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems,
-                         0, sizeof(OType));
+                         0, typeToNumBits(output->dtype()));
     create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows, cols,
                          SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems,
-                         cols, sizeof(OType));
+                         cols, typeToNumBits(output->dtype()));
   }

   const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
@@ -1071,10 +1074,9 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input,
                            cudaStream_t stream) {
   using namespace gated_kernels;
   Tensor grad_empty_tensor;
-  const Tensor &grad_tensor =
-      IS_DGATED ? *(reinterpret_cast<const Tensor *>(grad)) : grad_empty_tensor;
-  const Tensor gated_input_tensor = *reinterpret_cast<const Tensor *>(gated_input);
-  Tensor *output_tensor = reinterpret_cast<Tensor *>(output);
+  const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
+  const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input);
+  Tensor *output_tensor = convertNVTETensorCheck(output);

   if (is_supported_by_CC_100()) {
     quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
transformer_engine/common/util/cast_kernels.cuh

@@ -904,15 +904,15 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
   alignas(64) CUtensorMap tensor_map_output{};

   create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y,
-                       FP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
+                       FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
   if constexpr (IS_DACT) {
     create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y,
-                         FP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
+                         FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
   }
   create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y,
-                       FP8_SHMEM_DIM_X, cols, 0, sizeof(OType));
+                       FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype));

   cast_fp8_2D_kernel<IS_DBIAS, IS_DACT, ParamOP, OP, IType, OType>
       <<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output,
@@ -1004,24 +1004,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
   alignas(64) CUtensorMap tensor_map_output_colwise{};

   create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y,
-                       MXFP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
+                       MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
   if constexpr (IS_DACT) {
     create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, MXFP8_SHMEM_DIM_Y,
-                         MXFP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
+                         MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
   }
   if (use_rowwise_scaling) {
     create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, MXFP8_SHMEM_DIM_Y,
-                         MXFP8_SHMEM_DIM_X, cols, 0, sizeof(OType));
+                         MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype()));
   }
   if (use_colwise_scaling) {
     create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
-                         MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, sizeof(OType));
+                         MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype()));
   }

   cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
@@ -1133,7 +1133,7 @@ static bool is_full_tile_1D_tensor(const Tensor *const t) {
 bool dimensions_supported_by_TMA(const Tensor *const t) {
   const size_t cols = t->flat_last_dim();
   constexpr int TMA_bytes = 16;
-  const int alignment_requirement = TMA_bytes / typeToSize(t->dtype());
+  const int alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
   return cols % alignment_requirement == 0;
 }
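The switch from bytes to bits matters once sub-byte element types (such as 4-bit floats) enter the picture, since a byte-based divide cannot express an element smaller than one byte. With the 16-byte TMA requirement above, the element-count alignment works out as follows (pure arithmetic, shown as a sketch):

    static_assert((16 * 8) / 8 == 16, "8-bit elements: columns must be a multiple of 16");
    static_assert((16 * 8) / 4 == 32, "4-bit elements: columns must be a multiple of 32");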
@@ -1254,23 +1254,23 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
   const Tensor *activation_input_tensor;
   if constexpr (IS_DBIAS || IS_DACT) {
     // backward - input is incoming gradient
-    input_tensor = reinterpret_cast<const Tensor *>(grad);
-    activation_input_tensor = reinterpret_cast<const Tensor *>(input);
+    input_tensor = convertNVTETensorCheck(grad);
+    activation_input_tensor = convertNVTETensor(input);
   } else {
     // forward - input is activation input
-    input_tensor = reinterpret_cast<const Tensor *>(input);
+    input_tensor = convertNVTETensorCheck(input);
     activation_input_tensor = nullptr;
   }
-  auto output_tensor = reinterpret_cast<Tensor *>(output);
-  auto dbias_tensor = reinterpret_cast<Tensor *>(dbias);
-  auto workspace_tensor = reinterpret_cast<Tensor *>(workspace);
+  auto output_tensor = convertNVTETensorCheck(output);
+  auto dbias_tensor = convertNVTETensor(dbias);
+  auto workspace_tensor = convertNVTETensor(workspace);

   const QuantizationConfig *quant_config_cpp =
       reinterpret_cast<const QuantizationConfig *>(quant_config);
   // extract noop tensor from quant_config_cpp if it's not null
   const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr;
-  const auto noop_tensor = noop != nullptr ? *(reinterpret_cast<const Tensor *>(noop)) : Tensor();
+  const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor();

   switch (output_tensor->scaling_mode) {
     case NVTE_DELAYED_TENSOR_SCALING: {
@@ -1315,12 +1315,25 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
             "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
       bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false;
       float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
-      FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data()
-          ? FP8BlockwiseRowwiseOption::ROWWISE
-          : FP8BlockwiseRowwiseOption::NONE;
-      FP8BlockwiseColumnwiseOption columnwise_option = output_tensor->has_columnwise_data()
-          ? FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE
-          : FP8BlockwiseColumnwiseOption::NONE;
+      FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
+      FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
+      if (output_tensor->has_data()) {
+        bool rowwise_compact = quant_config_cpp
+            ? quant_config_cpp->float8_block_scale_tensor_format == Float8BlockScaleTensorFormat::COMPACT
+            : false;
+        rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
+                                         : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
+      }
+      if (output_tensor->has_columnwise_data()) {
+        bool columnwise_compact = quant_config_cpp
+            ? quant_config_cpp->float8_block_scale_tensor_format == Float8BlockScaleTensorFormat::COMPACT
+            : false;
+        columnwise_option = columnwise_compact ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
+                                               : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+      }
       quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv,
                                           output_tensor->columnwise_scale_inv, output_tensor->data,
                                           output_tensor->columnwise_data, epsilon, rowwise_option,
transformer_engine/common/util/cuda_driver.cpp

@@ -13,16 +13,42 @@ namespace transformer_engine {
 namespace cuda_driver {

-void *get_symbol(const char *symbol) {
-  void *entry_point;
+#ifndef USE_ROCM
+typedef cudaError_t (*VersionedGetEntryPoint)(const char *, void **, unsigned int,
+                                              unsigned long long,  // NOLINT(*)
+                                              cudaDriverEntryPointQueryResult *);
+typedef cudaError_t (*GetEntryPoint)(const char *, void **, unsigned long long,  // NOLINT(*)
+                                     cudaDriverEntryPointQueryResult *);
+#endif
+
+void *get_symbol(const char *symbol, int cuda_version) {
+#ifndef USE_ROCM
+  constexpr char driver_entrypoint[] = "cudaGetDriverEntryPoint";
+  constexpr char driver_entrypoint_versioned[] = "cudaGetDriverEntryPointByVersion";
+  // We link to the libcudart.so already, so can search for it in the current context
+  static GetEntryPoint driver_entrypoint_fun =
+      reinterpret_cast<GetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint));
+  static VersionedGetEntryPoint driver_entrypoint_versioned_fun =
+      reinterpret_cast<VersionedGetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint_versioned));
+#endif
+  void *entry_point = nullptr;
 #ifdef USE_ROCM
   hipDriverProcAddressQueryResult driver_result;
   NVTE_CHECK_CUDA(hipGetProcAddress(symbol, &entry_point,
                                     HIP_VERSION_MAJOR * 100 + HIP_VERSION_MINOR, 0, &driver_result));
   NVTE_CHECK(driver_result == HIP_GET_PROC_ADDRESS_SUCCESS,
              "Could not find CUDA driver entry point for ", symbol);
 #else
   cudaDriverEntryPointQueryResult driver_result;
-  NVTE_CHECK_CUDA(cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result));
+  if (driver_entrypoint_versioned_fun != nullptr) {
+    // Found versioned entrypoint function
+    NVTE_CHECK_CUDA(driver_entrypoint_versioned_fun(symbol, &entry_point, cuda_version,
+                                                    cudaEnableDefault, &driver_result));
+  } else {
+    NVTE_CHECK(driver_entrypoint_fun != nullptr, "Error finding the CUDA Runtime-Driver interop.");
+    // Versioned entrypoint function not found
+    NVTE_CHECK_CUDA(driver_entrypoint_fun(symbol, &entry_point, cudaEnableDefault, &driver_result));
+  }
   NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess,
              "Could not find CUDA driver entry point for ", symbol);
 #endif
transformer_engine/common/util/cuda_driver.h

@@ -19,7 +19,7 @@ namespace transformer_engine {
 namespace cuda_driver {

 /*! \brief Get pointer corresponding to symbol in CUDA driver library */
-void *get_symbol(const char *symbol);
+void *get_symbol(const char *symbol, int cuda_version = 12010);

 /*! \brief Call function in CUDA driver library
  *
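A hypothetical call site, assuming the usual CUDA version encoding (major * 1000 + minor * 10, so 12010 denotes CUDA 12.1):

    // Resolve a driver symbol against the default (12.1) or an explicit ABI version.
    void *fn     = transformer_engine::cuda_driver::get_symbol("cuMemAlloc");
    void *fn_128 = transformer_engine::cuda_driver::get_symbol("cuMemAlloc", 12080);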
transformer_engine/common/util/dequantize_kernels.cuh

@@ -326,9 +326,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
   alignas(64) CUtensorMap tensor_map_output{};

   create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y,
-                       SHMEM_DIM_X, cols, 0, sizeof(IType));
+                       SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
   create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y,
-                       SHMEM_DIM_X, cols, 0, sizeof(OType));
+                       SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype()));

   dequantize_mxfp8_kernel<IType, OType, SCALE_DIM_Y, SCALE_DIM_X>
       <<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_output, scales_ptr,
transformer_engine/common/util/multi_stream.cpp (new file, 0 → 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
#define TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_

#include "multi_stream.h"

#include <transformer_engine/multi_stream.h>

#include <mutex>
#include <vector>

#include "cuda_runtime.h"
#include "logging.h"

namespace transformer_engine::detail {

cudaStream_t get_compute_stream(int idx) {
  const size_t num_streams = nvte_get_num_compute_streams();
  NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx,
             ", but there are ", num_streams, " streams)");
  static std::vector<cudaStream_t> streams(num_streams);
  static std::once_flag stream_init_flag;
  auto init = [&]() {
    for (size_t i = 0; i < num_streams; i++) {
      NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&streams[i], cudaStreamNonBlocking, -1));
    }
  };
  std::call_once(stream_init_flag, init);
  return streams[idx];
}

cudaEvent_t get_compute_stream_event(int idx) {
  const size_t num_streams = nvte_get_num_compute_streams();
  NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx,
             ", but there are ", num_streams, " streams)");
  static std::vector<cudaEvent_t> events(num_streams);
  static std::once_flag event_init_flag;
  auto init = [&]() {
    for (size_t i = 0; i < num_streams; i++) {
      NVTE_CHECK_CUDA(cudaEventCreate(&events[i]));
    }
  };
  std::call_once(event_init_flag, init);
  return events[idx];
}

int get_num_compute_streams() {
  static constexpr int num_compute_streams = 4;
  return num_compute_streams;
}

}  // namespace transformer_engine::detail

int nvte_get_num_compute_streams() { return transformer_engine::detail::get_num_compute_streams(); }

#endif  // TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
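A sketch of how a caller might round-robin work over this pool (illustrative only; the launch itself is elided):

    const int n = nvte_get_num_compute_streams();  // fixed at 4 above
    for (int i = 0; i < num_jobs; ++i) {
      cudaStream_t s = transformer_engine::detail::get_compute_stream(i % n);
      // enqueue job i on stream s ...
    }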
transformer_engine/common/util/multi_stream.h (new file, 0 → 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
#define TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_

namespace transformer_engine::detail {

int get_num_compute_streams();

cudaStream_t get_compute_stream(int idx);

cudaEvent_t get_compute_stream_event(int idx);

}  // namespace transformer_engine::detail

#endif  // TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
transformer_engine/common/util/padding.cu

@@ -155,8 +155,8 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o
   // Input matrices are divided into tiles
   // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
-  const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
-  const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
+  const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);
+  const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);

   // Add tensors to kernel argument struct
   MultiPaddingArgs kernel_args;
@@ -211,8 +211,8 @@ void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETe
   std::vector<Tensor *> input_list_, output_list_;
   std::vector<int> padded_num_rows_list_;
   for (size_t i = 0; i < num_tensors; ++i) {
-    input_list_.push_back(reinterpret_cast<Tensor *>(const_cast<NVTETensor &>(input_list[i])));
-    output_list_.push_back(reinterpret_cast<Tensor *>(output_list[i]));
+    input_list_.push_back(convertNVTETensorCheck(input_list[i]));
+    output_list_.push_back(convertNVTETensorCheck(output_list[i]));
     padded_num_rows_list_.push_back(padded_num_rows_list[i]);
   }
   multi_padding(input_list_, output_list_, padded_num_rows_list_, stream);
transformer_engine/common/util/pybind_helper.h

@@ -80,6 +80,10 @@
       .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \
       .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8)                                   \
       .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);                    \
+  pybind11::enum_<transformer_engine::Float8BlockScaleTensorFormat>(                          \
+      m, "Float8BlockScaleTensorFormat", pybind11::module_local())                            \
+      .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY)      \
+      .value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT);           \
   pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType",                  \
                                                        pybind11::module_local())              \
       .value("RS", transformer_engine::CommOverlapType::RS)                                   \
transformer_engine/common/utils.cuh

@@ -21,6 +21,10 @@ using namespace __hip_internal;
 #include <cuda_bf16.h>
 #include <cuda_fp8.h>
+#if CUDA_VERSION >= 12080
+#include <cuda_fp4.h>
+#endif
+
 #ifdef __HIP_PLATFORM_AMD__
 typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
 #else
transformer_engine/debug/features/api.py

@@ -12,7 +12,7 @@ from nvdlfw_inspect.registry import Registry
 import torch

 from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
-from transformer_engine.pytorch.tensor import all_tensor_types
+from transformer_engine.pytorch.tensor import get_all_tensor_types
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
 from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
@@ -424,7 +424,7 @@ class TransformerEngineAPI(BaseNamespaceAPI):
         if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]:
             assert ret is None
         if api_name == "modify_tensor":
-            assert type(ret) in all_tensor_types
+            assert type(ret) in get_all_tensor_types()
             if (
                 type(ret) == torch.Tensor  # pylint: disable=unidiomatic-typecheck
                 and "dtype" in kwargs
@@ -438,4 +438,4 @@ class TransformerEngineAPI(BaseNamespaceAPI):
     def end_debug(self):
         """This function is called by the nvidia-dlframework-inspect after every debug_api.end_debug()"""
-        TEDebugState.reset()
+        TEDebugState._reset()
transformer_engine/debug/features/fake_quant.py

@@ -49,7 +49,7 @@ def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None):
             fp8_dtype = tex.DType.kFloat8E5M2
         amax = tensor.abs().max().float()
         one = torch.ones(1, device=tensor.device)
-        scale = _default_sf_compute(amax, one, fp8_max)
+        scale = _default_sf_compute(amax, one, fp8_max, 0)
         quantizer = Float8Quantizer(scale, amax, fp8_dtype)
     else:
transformer_engine/debug/features/log_fp8_tensor_stats.py

@@ -120,7 +120,6 @@ class LogFp8TensorStats(BaseLogTensorStats):
         if not rowwise:
             return  # tensor was already seen rowwise in the other gemm
-        tensor = tensor._data
         options = (
             config.get("start_step", None),
             config.get("end_step", None),
transformer_engine/debug/features/per_tensor_scaling.py

@@ -15,6 +15,7 @@ import transformer_engine_torch as tex
 from transformer_engine.pytorch.tensor import Quantizer
 from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8Tensor,
     Float8Quantizer,
+    Float8CurrentScalingQuantizer,
 )
 from transformer_engine.debug.features.api import TEConfigAPIMapper

@@ -39,7 +40,7 @@ def per_tensor_cast(
     }, "[NVTORCH INSPECT ERROR] Only 2 FP8 types: E4M3 and E5M2 are supported in TE."

     tensor = tensor.contiguous()
-    quantizer = Float8CurrentScalingQuantizer(fp8_dtype)
+    quantizer = Float8CurrentScalingQuantizer(fp8_dtype, device=tensor.device)
     if out is not None:
         quantizer.update_quantized(tensor, out)

@@ -81,7 +82,6 @@ class PerTensorScaling(TEConfigAPIMapper):
             transformer_engine:
               PerTensorScaling:
                 enabled: True
-                margin: 1
                 gemms: [dgrad]
                 tensors: [weight, activation]
     """

@@ -118,7 +118,7 @@ class PerTensorScaling(TEConfigAPIMapper):
             if key not in ["gemm", "tensor"]:
                 raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')

-        assert isinstance(default_quantizer, Float8CurrentScalingQuantizer), (
+        assert isinstance(default_quantizer, Float8Quantizer), (
             f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor: "
             "Per-tensor current scaling can be used only within `DelayedScaling` recipe autocast."
             f" {layer_name}"
transformer_engine/debug/features/utils/stats_computation.py

@@ -96,7 +96,10 @@ STATS = {
     "max": (torch.max, lambda buffers: max(_get(buffers, "max"))),
     "sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))),
     "mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))),
-    "numel": (lambda x: x.numel(), lambda buffers: sum(_get(buffers, "numel"))),
+    "numel": (
+        lambda x: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(),
+        lambda buffers: sum(_get(buffers, "numel")),
+    ),
     "l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))),
     "l2_norm_square": (lambda x: torch.sum(x**2),

@@ -137,7 +140,7 @@ STATS = {
         - min(_get(buffers, "dynamic_range_bottom")),
     ),
     "underflows%": (
-        lambda x: (x == 0).sum() / x.numel() * 100,
+        lambda x: (x.get_data_tensors()[0] == 0).sum() / x.get_data_tensors()[0].numel() * 100,
         lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")),
     ),
 }
transformer_engine/debug/pytorch/debug_quantization.py

@@ -14,10 +14,11 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch.tensor.quantized_tensor import (
     QuantizedTensor,
     Quantizer,
+    QuantizedTensorBase,
     prepare_for_saving,
     restore_from_saved,
 )

@@ -299,6 +300,7 @@ class DebugQuantizer(Quantizer):
                 iteration=self.iteration,
+                dtype=dtype,
             )
             if dtype is not None:
                 if columnwise_gemm_tensor.dtype != dtype:
                     raise ValueError("Dtype does not match the output of the modify_tensor call")

         if self.rowwise_tensor_plan == API_CALL_MODIFY:

@@ -311,6 +313,7 @@ class DebugQuantizer(Quantizer):
                 iteration=self.iteration,
+                dtype=dtype,
             )
             if dtype is not None:
                 if rowwise_gemm_tensor.dtype != dtype:
                     raise ValueError("Dtype does not match the output of the modify_tensor call")

@@ -332,6 +335,7 @@ class DebugQuantizer(Quantizer):
             quantizer=self,
             layer_name=self.layer_name,
             tensor_name=self.tensor_name,
+            original_tensor=tensor,
         )

     def process_gemm_output(self, tensor: torch.Tensor):

@@ -455,8 +459,12 @@ class DebugQuantizer(Quantizer):
                 return True
         return False

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        """Probably not needed for debug quantizer"""
+        return None

-class DebugQuantizedTensor:
+class DebugQuantizedTensor(QuantizedTensorBase):
     """
     Class containing quantized tensors after debug. Depending on configuration
     it can contain one or two different objects. These objects can be accessed by the method

@@ -470,6 +478,7 @@ class DebugQuantizedTensor:
         quantizer,
         layer_name=None,
         tensor_name=None,
+        original_tensor=None,
     ):
         self.rowwise_gemm_tensor = rowwise_gemm_tensor

@@ -477,6 +486,7 @@ class DebugQuantizedTensor:
         self.quantizer = quantizer
         self._layer_name = layer_name
         self._tensor_name = tensor_name
+        self._original_tensor = original_tensor

     def prepare_for_saving(self):
         """Prepare for saving method override"""

@@ -524,5 +534,5 @@ class DebugQuantizedTensor:
         """Size of the tensor."""
         return self.rowwise_gemm_tensor.size()

-    def update_usage(self, rowwise_usage: bool, columnwise_usage: bool):
+    def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None):
         """Update usage of the tensor."""