OpenDAS / TransformerEngine / Commits

Commit 0d874a4e, authored Mar 03, 2026 by wenjh

Merge branch 'nv_main' of v2.12

Parents: a68e5f87, dfdd3820
Changes: 640
Showing 20 changed files with 1073 additions and 193 deletions (+1073 -193)
transformer_engine/common/recipe/__init__.py                                     +21  -19
transformer_engine/common/recipe/current_scaling.cu                              +1   -1
transformer_engine/common/recipe/delayed_scaling.cu                              +1   -1
transformer_engine/common/recipe/fp8_block_scaling.cu                            +1   -1
transformer_engine/common/recipe/mxfp8_scaling.cu                                +253 -0
transformer_engine/common/recipe/nvfp4.cu                                        +1   -1
transformer_engine/common/recipe/recipe_common.cuh                               +1   -1
transformer_engine/common/swizzle/swizzle.cu                                     +154 -80
transformer_engine/common/swizzle/swizzle_block_scaling.cu                       +17  -8
transformer_engine/common/transformer_engine.cpp                                 +604 -61
transformer_engine/common/transpose/cast_transpose.cu                            +1   -1
transformer_engine/common/transpose/cast_transpose.h                             +3   -4
transformer_engine/common/transpose/cast_transpose_fusion.cu                     +1   -3
transformer_engine/common/transpose/multi_cast_transpose.cu                      +1   -1
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu       +8   -6
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu       +1   -1
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu   +1   -1
transformer_engine/common/transpose/rtc/cast_transpose.cu                        +1   -1
transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu                 +1   -1
transformer_engine/common/transpose/rtc/swap_first_dims.cu                       +1   -1
Too many changes to show. To preserve performance only 640 of 640+ files are displayed.
transformer_engine/common/recipe/__init__.py
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

...
@@ -50,7 +50,7 @@ class MMParams:
     Parameters
     ----------
-    use_split_accumulator : bool, default = `True`
+    use_split_accumulator : bool, default = True
                             Use FP8 fast accumulation on Hopper or Ada. For more details,
                             see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
     """

...
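For reference, CUBLASLT_MATMUL_DESC_FAST_ACCUM is a cuBLASLt matmul-descriptor attribute. A minimal sketch of how such an option is toggled (illustrative only; the helper name is hypothetical, not Transformer Engine code):

#include <cublasLt.h>

// Hypothetical helper: use_split_accumulator = true requests the slower,
// more accurate split-accumulation path, i.e. fast accumulation disabled.
void set_fp8_accumulation_mode(cublasLtMatmulDesc_t desc, bool use_split_accumulator) {
  const int8_t fast_accum = use_split_accumulator ? 0 : 1;
  cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                 &fast_accum, sizeof(fast_accum));
}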
@@ -159,7 +159,7 @@ class DelayedScaling(Recipe):
                                       recipe: DelayedScaling) -> Tensor
         where `Tensor` is a framework tensor type.
-    reduce_amax: bool, default = `True`
+    reduce_amax: bool, default = True
         By default, if `torch.distributed` is initialized, the `amax` value for FP8
         tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
         call). This keeps the amaxes and scaling factors synced across the given

...
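For context, a minimal sketch of the autocast call the docstring refers to, assuming Transformer Engine's PyTorch bindings (`model` and `inp` are placeholders):

import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# reduce_amax=True: amaxes are all-reduced over the group passed to autocast
recipe = DelayedScaling(reduce_amax=True)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe, fp8_group=dist.group.WORLD):
    out = model(inp)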
@@ -167,13 +167,13 @@ class DelayedScaling(Recipe):
         GPU maintains local amaxes and scaling factors. To ensure results are
         numerically identical across checkpointing boundaries in this case, all
         ranks must checkpoint in order to store the local tensors.
-    fp8_dpa: bool, default = `False`
+    fp8_dpa: bool, default = False
         Whether to enable FP8 dot product attention (DPA). When the model is placed in an
         `autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
         inputs from higher precision to FP8, performs attention in FP8, and casts tensors
         back to higher precision as outputs. FP8 DPA currently is only supported in the
         `FusedAttention` backend.
-    fp8_mha: bool, default = `False`
+    fp8_mha: bool, default = False
         Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
         operations mentioned above at the DPA boundaries. Currently only standard MHA modules
         i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When

...
@@ -422,11 +422,11 @@ class NVFP4BlockScaling(Recipe):
     ----------
     fp4_format : {Format.E2M1}, default = Format.E2M1
         FP4 data type.
-    disable_rht : bool, default = `False`
+    disable_rht : bool, default = False
         If set to `True`, random Hadamard transforms are not applied to any tensor.
-    disable_stochastic_rounding : bool, default = `False`
+    disable_stochastic_rounding : bool, default = False
         If set to `True`, stochastic rounding is disabled during quantization for all tensors.
-    disable_2d_quantization : bool, default = `False`
+    disable_2d_quantization : bool, default = False
         If set to `True`, 1D block scaling with block size 16 is used for all tensors.
     """

...
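These flags are plain constructor arguments; a minimal sketch, assuming the classes accept the documented keywords:

from transformer_engine.common.recipe import DelayedScaling, Format, NVFP4BlockScaling

# FP8 attention enabled end-to-end (fp8_mha removes the fp8_dpa boundary casts)
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=True)

# NVFP4 with every randomization knob left at its documented default (enabled)
fp4_recipe = NVFP4BlockScaling(
    fp4_format=Format.E2M1,
    disable_rht=False,
    disable_stochastic_rounding=False,
    disable_2d_quantization=False,
)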
@@ -492,17 +492,19 @@ class CustomRecipe(Recipe):
     Parameters
     ----------
     qfactory : Callable
-        Factory callable that returns a quantizer instance for a
-        given semantic tensor role.
-        The callable is typically invoked as:
-            qfactory(
-                role: str,
-            )
-        Where `role` is one of the following strings for e.g. te.Linear
-        (stable public contract):
-        - forward: "linear_input", "linear_weight", "linear_output"
-        - backward: "linear_grad_output", "linear_grad_input"
+        Factory callable that returns a quantizer instance for a
+        given semantic tensor role.
+        The callable is typically invoked as::
+
+            qfactory(
+                role: str,
+            )
+
+        Where `role` is one of the following strings for e.g. te.Linear
+        (stable public contract):
+        - forward: "linear_input", "linear_weight", "linear_output"
+        - backward: "linear_grad_output", "linear_grad_input"
     """

     qfactory: Callable[..., Any]

...
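A hedged sketch of a factory satisfying this contract; `make_fp8_quantizer` and `make_mxfp8_quantizer` are hypothetical stand-ins for whatever quantizer constructors the framework integration provides:

from transformer_engine.common.recipe import CustomRecipe

def qfactory(role: str):
    # Roles are the stable public contract documented above for te.Linear.
    if role in ("linear_input", "linear_weight", "linear_output"):
        return make_fp8_quantizer()        # forward-pass tensors
    if role in ("linear_grad_output", "linear_grad_input"):
        return make_mxfp8_quantizer()      # backward-pass tensors
    raise ValueError(f"Unknown tensor role: {role}")

recipe = CustomRecipe(qfactory=qfactory)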
transformer_engine/common/recipe/current_scaling.cu
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

...
transformer_engine/common/recipe/delayed_scaling.cu
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

...
transformer_engine/common/recipe/fp8_block_scaling.cu
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

...
transformer_engine/common/recipe/mxfp8_scaling.cu
0 → 100644
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/recipe.h>

#include "../common.h"
#include "../util/ptx.cuh"
#include "../utils.cuh"

namespace transformer_engine {
namespace mxfp8_scaling_recipe {

constexpr int rowwise_row_padding = 128;  // Row padding of rowwise_scale and rowwise_amax
constexpr int rowwise_col_padding = 4;    // Column padding of rowwise_scale and rowwise_amax
constexpr int colwise_row_padding = 4;    // Row padding of colwise_scale and colwise_amax
constexpr int colwise_col_padding = 128;  // Column padding of colwise_scale and colwise_amax

constexpr int kRowsPerTile = 32;   // Rows each block processes
constexpr int kColsPerTile = 128;  // Columns each block processes
constexpr int kThreadsPerBlock = 128;
template <typename IType>
__global__ void __launch_bounds__(kThreadsPerBlock)
    mxfp8_scaling_compute_partial_amax_kernel(const IType *input, IType *amax_rowwise,
                                              IType *amax_colwise, int amax_rowwise_stride,
                                              int amax_colwise_stride, int rows, int cols,
                                              size_t start_offset, size_t len) {
  __shared__ float smem_amax_rowwise[kRowsPerTile][kColsPerTile / 32];

  size_t end_offset = start_offset + len;
  const IType *input_minus_offset = input - start_offset;

  int warp_idx = threadIdx.x / 32;
  int lane_idx = threadIdx.x % 32;

  int c = blockIdx.x * kColsPerTile + threadIdx.x;
  int r = blockIdx.y * kRowsPerTile;

  float col_amax = 0.0f;
#pragma unroll
  for (int i = 0; i < kRowsPerTile; i++) {
    size_t idx = r * cols + c;
    float row_amax = 0.0f;
    if (r < rows && c < cols && idx >= start_offset && idx < end_offset) {
      float abs_input = fabs(static_cast<float>(input_minus_offset[idx]));
      row_amax = fmaxf(row_amax, abs_input);
      col_amax = fmaxf(col_amax, abs_input);
    }
#pragma unroll
    for (int delta = 16; delta > 0; delta /= 2) {
      float other_row_amax = __shfl_down_sync(0xFFFFFFFF, row_amax, delta);
      row_amax = fmaxf(row_amax, other_row_amax);
    }
    if (lane_idx == 0) {
      smem_amax_rowwise[i][warp_idx] = row_amax;
    }
    r++;
  }
  amax_colwise[blockIdx.y * amax_colwise_stride + c] = static_cast<IType>(col_amax);

  __syncthreads();

  int r_ = threadIdx.x / (kColsPerTile / 32);  // rows in shared memory
  int c_ = threadIdx.x % (kColsPerTile / 32);  // cols in shared memory
  r = blockIdx.y * kRowsPerTile + r_;
  c = blockIdx.x * kColsPerTile / 32 + c_;
  amax_rowwise[r * amax_rowwise_stride + c] = static_cast<IType>(smem_amax_rowwise[r_][c_]);
}
template <typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock)
    mxfp8_scaling_partial_cast_kernel(const IType *input, OType *output_rowwise,
                                      OType *output_colwise, const e8m0_t *scale_inv_rowwise,
                                      const e8m0_t *scale_inv_colwise, int scale_inv_rowwise_stride,
                                      int scale_inv_colwise_stride, int rows, int cols,
                                      size_t start_offset, size_t len) {
  __shared__ float smem_scales_rowwise[kRowsPerTile][kColsPerTile / 32];
  __shared__ float smem_scales_colwise[kColsPerTile];

  // Load scales_rowwise
  {
    int r_ = threadIdx.x / (kColsPerTile / 32);  // rows in shared memory
    int c_ = threadIdx.x % (kColsPerTile / 32);  // cols in shared memory
    int r = blockIdx.y * kRowsPerTile + r_;
    int c = blockIdx.x * kColsPerTile / 32 + c_;
    size_t idx = r * scale_inv_rowwise_stride + c;
    smem_scales_rowwise[r_][c_] = ptx::exp2f_rcp(scale_inv_rowwise[idx]);
  }
  // Load scales_colwise
  {
    int c_ = threadIdx.x;
    int r = blockIdx.y * kRowsPerTile / 32;
    int c = blockIdx.x * kColsPerTile + c_;
    size_t idx = r * scale_inv_colwise_stride + c;
    smem_scales_colwise[c_] = ptx::exp2f_rcp(scale_inv_colwise[idx]);
  }
  __syncthreads();

  size_t end_offset = start_offset + len;
  const IType *input_minus_offset = input - start_offset;
  OType *output_rowwise_minus_offset = output_rowwise - start_offset;
  OType *output_colwise_minus_offset = output_colwise - start_offset;

  int warp_idx = threadIdx.x / 32;
  // int lane_idx = threadIdx.x % 32;

  int c = blockIdx.x * kColsPerTile + threadIdx.x;
  int r = blockIdx.y * kRowsPerTile;

#pragma unroll
  for (int i = 0; i < kRowsPerTile; i++) {
    size_t idx = r * cols + c;
    if (r < rows && c < cols && idx >= start_offset && idx < end_offset) {
      float inp = static_cast<float>(input_minus_offset[idx]);
      OType out_rowwise = static_cast<OType>(inp * smem_scales_rowwise[i][warp_idx]);
      OType out_colwise = static_cast<OType>(inp * smem_scales_colwise[threadIdx.x]);
      output_rowwise_minus_offset[idx] = out_rowwise;
      output_colwise_minus_offset[idx] = out_colwise;
    }
    r++;
  }
}
void mxfp8_scaling_compute_partial_amax(const Tensor input, Tensor amax_rowwise,
                                        Tensor amax_colwise, int rows, int cols,
                                        size_t start_offset, cudaStream_t stream) {
  NVTE_CHECK(rows % 32 == 0, "rows must be divisible by 32");
  NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");
  NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
  NVTE_CHECK(start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols,
             "Invalid start_offset");
  NVTE_CHECK(amax_rowwise.data.shape.size() == 2, "amax_rowwise must be a 2D tensor");
  NVTE_CHECK(amax_rowwise.data.shape[0] % rowwise_row_padding == 0,
             "Wrong padding of amax_rowwise's rows");
  NVTE_CHECK(amax_rowwise.data.shape[0] >= rows, "Invalid rows");
  NVTE_CHECK(amax_rowwise.data.shape[1] % rowwise_col_padding == 0,
             "Wrong padding of amax_rowwise's cols");
  NVTE_CHECK(amax_rowwise.data.shape[1] >= cols / 32, "Invalid cols");
  NVTE_CHECK(amax_rowwise.dtype() == input.dtype(), "Wrong dtype of amax_rowwise");
  NVTE_CHECK(amax_colwise.data.shape.size() == 2, "amax_colwise must be a 2D tensor");
  NVTE_CHECK(amax_colwise.data.shape[0] % colwise_row_padding == 0,
             "Wrong padding of amax_colwise's rows");
  NVTE_CHECK(amax_colwise.data.shape[0] >= rows / 32, "Invalid rows");
  NVTE_CHECK(amax_colwise.data.shape[1] % colwise_col_padding == 0,
             "Wrong padding of amax_colwise's cols");
  NVTE_CHECK(amax_colwise.data.shape[1] >= cols, "Invalid cols");
  NVTE_CHECK(amax_colwise.dtype() == input.dtype(), "Wrong dtype of amax_colwise");

  int blocks_x = (cols + kColsPerTile - 1) / kColsPerTile;
  int blocks_y = (rows + kRowsPerTile - 1) / kRowsPerTile;
  dim3 grid(blocks_x, blocks_y);

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      input.dtype(), IType,
      mxfp8_scaling_compute_partial_amax_kernel<IType><<<grid, kColsPerTile, 0, stream>>>(
          reinterpret_cast<const IType *>(input.data.dptr),
          reinterpret_cast<IType *>(amax_rowwise.data.dptr),
          reinterpret_cast<IType *>(amax_colwise.data.dptr), amax_rowwise.data.shape[1],
          amax_colwise.data.shape[1], rows, cols, start_offset, input.data.shape[0]);)
}
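The padding checks above fully determine the amax buffer shapes. A small sketch of the implied shape arithmetic (illustrative, not part of the diff):

// For a rows x cols input, the checks above imply:
//   amax_rowwise: first dim a multiple of 128 and >= rows,
//                 second dim a multiple of 4 and >= cols / 32
//   amax_colwise: first dim a multiple of 4 and >= rows / 32,
//                 second dim a multiple of 128 and >= cols
constexpr int round_up(int x, int multiple) { return (x + multiple - 1) / multiple * multiple; }

// e.g. rows = 160, cols = 320:
//   amax_rowwise: (round_up(160, 128), round_up(320 / 32, 4)) = (256, 12)
//   amax_colwise: (round_up(160 / 32, 4), round_up(320, 128)) = (8, 384)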
void mxfp8_scaling_partial_cast(const Tensor input, Tensor output_rowwise, Tensor output_colwise,
                                const Tensor scale_inv_rowwise, const Tensor scale_inv_colwise,
                                int rows, int cols, size_t start_offset, cudaStream_t stream) {
  NVTE_CHECK(rows % 32 == 0, "rows must be divisible by 32");
  NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");
  NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
  NVTE_CHECK(start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols,
             "Invalid start_offset");
  NVTE_CHECK(output_rowwise.data.shape.size() == 1, "output_rowwise must be a 1D tensor");
  NVTE_CHECK(output_colwise.data.shape.size() == 1, "output_colwise must be a 1D tensor");
  NVTE_CHECK(output_rowwise.data.shape[0] == input.data.shape[0],
             "Size of input and output_rowwise mismatch");
  NVTE_CHECK(output_colwise.data.shape[0] == input.data.shape[0],
             "Size of input and output_colwise mismatch");
  NVTE_CHECK(output_rowwise.dtype() == DType::kFloat8E4M3 ||
                 output_rowwise.dtype() == DType::kByte,
             "output_rowwise should be e4m3 or uint8");
  NVTE_CHECK(output_colwise.dtype() == DType::kFloat8E4M3 ||
                 output_colwise.dtype() == DType::kByte,
             "output_colwise should be e4m3 or uint8");
  NVTE_CHECK(scale_inv_rowwise.data.shape.size() == 2, "scale_inv_rowwise must be a 2D tensor");
  NVTE_CHECK(scale_inv_rowwise.data.shape[0] % rowwise_row_padding == 0,
             "Wrong padding of scale_inv_rowwise's rows");
  NVTE_CHECK(scale_inv_rowwise.data.shape[0] >= rows, "Invalid rows");
  NVTE_CHECK(scale_inv_rowwise.data.shape[1] % rowwise_col_padding == 0,
             "Wrong padding of scale_inv_rowwise's cols");
  NVTE_CHECK(scale_inv_rowwise.data.shape[1] >= cols / 32, "Invalid cols");
  NVTE_CHECK(scale_inv_rowwise.dtype() == DType::kByte, "Wrong dtype of scale_inv_rowwise");
  NVTE_CHECK(scale_inv_colwise.data.shape.size() == 2, "scale_inv_colwise must be a 2D tensor");
  NVTE_CHECK(scale_inv_colwise.data.shape[0] % colwise_row_padding == 0,
             "Wrong padding of scale_inv_colwise's rows");
  NVTE_CHECK(scale_inv_colwise.data.shape[0] >= rows / 32, "Invalid rows");
  NVTE_CHECK(scale_inv_colwise.data.shape[1] % colwise_col_padding == 0,
             "Wrong padding of scale_inv_colwise's cols");
  NVTE_CHECK(scale_inv_colwise.data.shape[1] >= cols, "Invalid cols");
  NVTE_CHECK(scale_inv_colwise.dtype() == DType::kByte, "Wrong dtype of scale_inv_colwise");

  int blocks_x = (cols + kColsPerTile - 1) / kColsPerTile;
  int blocks_y = (rows + kRowsPerTile - 1) / kRowsPerTile;
  dim3 grid(blocks_x, blocks_y);

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      input.dtype(), IType,
      mxfp8_scaling_partial_cast_kernel<IType, fp8e4m3><<<grid, kColsPerTile, 0, stream>>>(
          reinterpret_cast<const IType *>(input.data.dptr),
          reinterpret_cast<fp8e4m3 *>(output_rowwise.data.dptr),
          reinterpret_cast<fp8e4m3 *>(output_colwise.data.dptr),
          reinterpret_cast<const e8m0_t *>(scale_inv_rowwise.data.dptr),
          reinterpret_cast<const e8m0_t *>(scale_inv_colwise.data.dptr),
          scale_inv_rowwise.data.shape[1], scale_inv_colwise.data.shape[1], rows, cols,
          start_offset, input.data.shape[0]);)
}
}  // namespace mxfp8_scaling_recipe
}  // namespace transformer_engine

void nvte_mxfp8_scaling_compute_partial_amax(const NVTETensor input, NVTETensor amax_rowwise,
                                             NVTETensor amax_colwise, int rows, int cols,
                                             size_t start_offset, cudaStream_t stream) {
  NVTE_API_CALL(nvte_mxfp8_scaling_compute_partial_amax);
  using namespace transformer_engine;
  mxfp8_scaling_recipe::mxfp8_scaling_compute_partial_amax(
      *convertNVTETensorCheck(input), *convertNVTETensorCheck(amax_rowwise),
      *convertNVTETensorCheck(amax_colwise), rows, cols, start_offset, stream);
}

void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_rowwise,
                                     NVTETensor output_colwise, const NVTETensor scale_inv_rowwise,
                                     const NVTETensor scale_inv_colwise, int rows, int cols,
                                     size_t start_offset, cudaStream_t stream) {
  NVTE_API_CALL(nvte_mxfp8_scaling_partial_cast);
  using namespace transformer_engine;
  mxfp8_scaling_recipe::mxfp8_scaling_partial_cast(
      *convertNVTETensorCheck(input), *convertNVTETensorCheck(output_rowwise),
      *convertNVTETensorCheck(output_colwise), *convertNVTETensorCheck(scale_inv_rowwise),
      *convertNVTETensorCheck(scale_inv_colwise), rows, cols, start_offset, stream);
}
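A hedged sketch of the call sequence these two entry points support (tensor construction and the amax-to-scale_inv conversion are elided; all variable names are placeholders):

// 1. Compute per-block partial amaxes for the shard [start_offset, start_offset + len)
//    of the logical rows x cols tensor.
nvte_mxfp8_scaling_compute_partial_amax(input, amax_rowwise, amax_colwise,
                                        rows, cols, start_offset, stream);
// 2. (Elsewhere) reduce the partial amaxes across shards and convert them to
//    E8M0 scale_inv factors.
// 3. Cast the same shard to FP8 using the precomputed scales.
nvte_mxfp8_scaling_partial_cast(input, output_rowwise, output_colwise,
                                scale_inv_rowwise, scale_inv_colwise,
                                rows, cols, start_offset, stream);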
transformer_engine/common/recipe/nvfp4.cu
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

...
transformer_engine/common/recipe/recipe_common.cuh
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

...
transformer_engine/common/swizzle/swizzle.cu
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

...
@@ -342,68 +342,122 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
 }
 
 }  // namespace
 
 void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
-  NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING ||
-                 input->scaling_mode == NVTE_NVFP4_1D_SCALING,
-             "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
-  NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()),
-             "Input tensor has invalid dtype (", to_string(input->dtype()), ").");
-
   // Do nothing if tensor is empty
   if (input->data.numel() == 0) {
     return;
   }
 
+  // Check scaling mode
+  const auto& scaling_mode = input->scaling_mode;
+  NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
+             "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
+
+  // Check tensors
   CheckInputTensor(*input, "scaling_factor_input");
   CheckInputTensor(*output, "scaling_factor_output");
+  NVTE_CHECK(!input->with_gemm_swizzled_scales,
+             "Expected input tensor with scales in compact format.");
+  NVTE_CHECK(output->with_gemm_swizzled_scales,
+             "Expected output tensor with scales in GEMM swizzled format.");
+  switch (scaling_mode) {
+    case NVTE_MXFP8_1D_SCALING:
+      NVTE_CHECK(is_fp8_dtype(input->dtype()),
+                 "Input tensor has invalid dtype (expected FP8, got ", to_string(input->dtype()),
+                 ").");
+      break;
+    case NVTE_NVFP4_1D_SCALING:
+      NVTE_CHECK(is_fp4_dtype(input->dtype()),
+                 "Input tensor has invalid dtype (expected FP4, got ", to_string(input->dtype()),
+                 ").");
+      break;
+    default:
+      NVTE_ERROR("Invalid scaling mode");
+  }
 
-  auto& scaling_mode = input->scaling_mode;
-  NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
-             "Unsupported scaling mode for swizzling.");
-  bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
+  // Check if scaling factors are non-trivial
+  const bool has_rowwise_scale_inv = input->scale_inv.has_data();
+  const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data();
+  NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
+             "Input tensor has both row-wise and column-wise scaling factors");
+  if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) {
+    return;
+  }
 
-  // 1D block scaling, row-wise or colum-wise
-  int m, k;
-  if (input->has_data()) {
-    m = input->scale_inv.shape[0];
-    k = input->scale_inv.shape[1];
-  } else {
-    if (nvfp4) {
-      m = input->columnwise_scale_inv.shape[0];
-      k = input->columnwise_scale_inv.shape[1];
-    } else {
-      m = input->columnwise_scale_inv.shape[1];
-      k = input->columnwise_scale_inv.shape[0];
-    }
-  }
+  // Deduce tensor dims
+  int m{0}, k{0};
+  switch (scaling_mode) {
+    case NVTE_MXFP8_1D_SCALING: {
+      if (has_rowwise_scale_inv) {
+        NVTE_CHECK(input->scale_inv.shape.size() == 2,
+                   "Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
+        m = input->scale_inv.shape[0];
+        k = input->scale_inv.shape[1];
+      } else if (has_columnwise_scale_inv) {
+        NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
+                   "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
+                   ".");
+        m = input->columnwise_scale_inv.shape[1];
+        k = input->columnwise_scale_inv.shape[0];
+      }
+      break;
+    }
+    case NVTE_NVFP4_1D_SCALING: {
+      if (has_rowwise_scale_inv) {
+        NVTE_CHECK(input->scale_inv.shape.size() == 2,
+                   "Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
+        m = input->scale_inv.shape[0];
+        k = input->scale_inv.shape[1];
+      } else if (has_columnwise_scale_inv) {
+        NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
+                   "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
+                   ".");
+        m = input->columnwise_scale_inv.shape[0];
+        k = input->columnwise_scale_inv.shape[1];
+      }
+      break;
+    }
+    default:
+      NVTE_ERROR("Invalid scaling mode");
+  }
 
   // Check dims
   constexpr int SF_TILE_DIM_M = 128;
   constexpr int SF_TILE_DIM_K = 4;
   NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
   NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
   NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
 
-  if (output->has_data()) {
-    NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(),
-                                        output->scale_inv.shape.end(), 1, std::multiplies<int>()),
-               "Input.scale_inv size is not equal to Output.scale_inv size!");
-  }
-  if (output->has_columnwise_data()) {
-    NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(),
-                                        output->columnwise_scale_inv.shape.end(), 1,
-                                        std::multiplies<int>()),
-               "Input.columnwise_scale_inv size is not equal to "
-               "Output.columnwise_scale_inv size!");
-  }
+  // Check that output tensor matches input tensor
+  if (has_rowwise_scale_inv) {
+    NVTE_CHECK(output->scale_inv.has_data(),
+               "Output tensor does not have row-wise scaling factors.");
+    NVTE_CHECK(m * k == output->scale_inv.numel(), "Expected output tensor to have ", m * k,
+               " row-wise scaling factors, but got shape=", output->scale_inv.shape, ".");
+  }
+  if (has_columnwise_scale_inv) {
+    NVTE_CHECK(output->columnwise_scale_inv.has_data(),
+               "Output tensor does not have column-wise scaling factors.");
+    NVTE_CHECK(m * k == output->columnwise_scale_inv.numel(), "Expected output tensor to have ",
+               m * k, " column-wise scaling factors, but got shape=",
+               output->columnwise_scale_inv.shape, ".");
+  }
 
-  int num_tiles_m = m / SF_TILE_DIM_M;
-  int num_tiles_k = k / SF_TILE_DIM_K;
-  // For NVFP4, the scale inverse for tranposed data needs rowwise swizzle.
-  const bool rowwise_swizzle = input->has_data() || nvfp4;
-  const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4;
-  const dim3 block_size(TB_DIM, TB_DIM);
+  // Choose swizzle implementation
+  bool rowwise_swizzle{false}, columnwise_swizzle{false};
+  switch (scaling_mode) {
+    case NVTE_MXFP8_1D_SCALING: {
+      rowwise_swizzle = has_rowwise_scale_inv;
+      columnwise_swizzle = has_columnwise_scale_inv;
+      break;
+    }
+    case NVTE_NVFP4_1D_SCALING: {
+      // NVFP4 column-wise data is transposed, so row-wise and
+      // column-wise scales have same swizzling format
+      rowwise_swizzle = true;
+      columnwise_swizzle = false;
+      break;
+    }
+    default:
+      NVTE_ERROR("Invalid scaling mode");
+  }
+  const int num_tiles_m = m / SF_TILE_DIM_M;
+  const int num_tiles_k = k / SF_TILE_DIM_K;
+  dim3 block_size(TB_DIM, TB_DIM);
 
+  // Perform row-wise swizzle
   if (rowwise_swizzle) {
     int vec_load_size = (num_tiles_k - 1) % 4 + 1;
     /* there is no int3 and misaligned if using int4/int2 */

...
@@ -412,20 +466,32 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
     dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
     int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
 
-    int original_M, original_K;
-    void *input_scale_inv_ptr, *output_scale_inv_ptr;
-    if (!nvfp4 || input->has_data()) {
-      int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
-      original_M = input->flat_first_dim();
-      original_K = input->flat_last_dim() / block_scale_size;
-      input_scale_inv_ptr = input->scale_inv.dptr;
-      output_scale_inv_ptr = output->scale_inv.dptr;
-    } else {
-      original_M = input->flat_last_dim();
-      original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
-      input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
-      output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
-    }
+    int original_M{0}, original_K{0};
+    void *input_scale_inv_ptr{nullptr}, *output_scale_inv_ptr{nullptr};
+    switch (scaling_mode) {
+      case NVTE_MXFP8_1D_SCALING: {
+        original_M = input->flat_first_dim();
+        original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
+        input_scale_inv_ptr = input->scale_inv.dptr;
+        output_scale_inv_ptr = output->scale_inv.dptr;
+        break;
+      }
+      case NVTE_NVFP4_1D_SCALING: {
+        if (has_rowwise_scale_inv) {
+          original_M = input->flat_first_dim();
+          original_K = input->flat_last_dim() / NVFP4_BLOCK_SIZE;
+          input_scale_inv_ptr = input->scale_inv.dptr;
+          output_scale_inv_ptr = output->scale_inv.dptr;
+        } else if (has_columnwise_scale_inv) {
+          original_M = input->flat_last_dim();
+          original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
+          input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
+          output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
+        }
+        break;
+      }
+      default:
+        NVTE_ERROR("Invalid scaling mode");
+    }
 
     switch (vec_load_size) {

...
@@ -481,7 +547,10 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
         NVTE_ERROR("Not valid vec_load_size.");
         break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 
+  // Perform column-wise swizzle
   if (columnwise_swizzle) {
     int vec_load_size = (num_tiles_m - 1) % 4 + 1;
+    if (vec_load_size == 3) vec_load_size = 1;
     /* no int3 and misaligned if using int4/int2 */

...
@@ -490,8 +559,6 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
     int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
 
     const int original_M = input->flat_last_dim();
     const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
 
-    // NVFP4 shouldn't end up here because it only needs rowwise swizzle
-    NVTE_CHECK(!nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle");
     switch (vec_load_size) {
 #ifdef __HIP_PLATFORM_AMD__

...
@@ -552,8 +619,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
         NVTE_ERROR("Not valid vec_load_size.");
         break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
-  NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
 template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>

...
@@ -702,18 +769,24 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
     NVTE_CHECK((is_fp8 && is_mxfp8_scaling(scaling_mode)) ||
                    (is_fp4 && is_nvfp4_scaling(scaling_mode)),
                "Not implemented scaling mode " + to_string(scaling_mode) + ".");
+    NVTE_CHECK(!input[i]->with_gemm_swizzled_scales,
+               "Expected input tensors with scales in compact format.");
+    NVTE_CHECK(output[i]->with_gemm_swizzled_scales,
+               "Expected output tensors with scales in GEMM swizzled format.");
 
     // We don't allow empty tensors. They should be filtered out before calling this function.
-    if (input[i]->data.numel() == 0) {
-      NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty.");
-    }
+    NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty.");
 
     CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]");
     CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]");
 
-    all_has_data &= input[i]->has_data();
-    all_has_columnwise_data &= input[i]->has_columnwise_data();
-    all_nvfp4 &= is_nvfp4_scaling(scaling_mode);
+    all_has_data = all_has_data && input[i]->scale_inv.has_data();
+    all_has_columnwise_data =
+        (all_has_columnwise_data && input[i]->columnwise_scale_inv.has_data());
+    all_nvfp4 = all_nvfp4 && is_nvfp4_scaling(scaling_mode);
   }
   NVTE_CHECK(all_has_data || all_has_columnwise_data,
              "All tensors should have data or columnwise data.");
   NVTE_CHECK(!all_has_data || !all_has_columnwise_data,
             "All tensors have both data and columnwise data.");
 
   const bool rowwise_swizzle = all_has_data || all_nvfp4;
   const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4;

...
@@ -752,18 +825,19 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
     NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
     NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
 
-    if (output[i]->has_data()) {
-      NVTE_CHECK(m * k == std::accumulate(output[i]->scale_inv.shape.begin(),
-                                          output[i]->scale_inv.shape.end(), 1,
-                                          std::multiplies<int>()),
-                 "Input.scale_inv size is not equal to Output.scale_inv size!");
+    if (all_has_data) {
+      NVTE_CHECK(output[i]->scale_inv.has_data(), "Output tensor ", i,
+                 " does not have row-wise scaling factors.");
+      NVTE_CHECK(m * k == output[i]->scale_inv.numel(), "Expected output tensor ", i, " to have ",
+                 m * k, " row-wise scaling factors, but got shape=", output[i]->scale_inv.shape,
+                 ".");
     }
-    if (output[i]->has_columnwise_data()) {
-      NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(),
-                                          output[i]->columnwise_scale_inv.shape.end(), 1,
-                                          std::multiplies<int>()),
-                 "Input.columnwise_scale_inv size is not equal to "
-                 "Output.columnwise_scale_inv size!");
+    if (all_has_columnwise_data) {
+      NVTE_CHECK(output[i]->columnwise_scale_inv.has_data(), "Output tensor ", i,
+                 " does not have column-wise scaling factors.");
+      NVTE_CHECK(m * k == output[i]->columnwise_scale_inv.numel(), "Expected output tensor ", i,
+                 " to have ", m * k, " column-wise scaling factors, but got shape=",
+                 output[i]->columnwise_scale_inv.shape, ".");
     }
 
     int num_tiles_k = k / SF_TILE_DIM_K;

...
transformer_engine/common/swizzle/swizzle_block_scaling.cu
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

...
@@ -99,7 +99,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
   // calculate this warp's input base pointer
   constexpr uint32_t in_x_stride = WARP_SIZE * sizeof(uint4);
-  const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride;
+  const void* const warp_src =
+      (reinterpret_cast<const uint8_t*>(in) + in_tile_y * in_y_stride + in_tile_x * in_x_stride);
 
   // load scaling factors for this lane's initial four 1x128 tiles
   uint4 sf;

...
@@ -114,7 +115,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
   }
 
   // pack the exponent bits of the scaling factors
-  uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1);
+  uint32_t packed_exponents = ((sf.x >> 23) & 0xFF) | (((sf.y >> 23) & 0xFF) << 8) |
+                              (((sf.z >> 23) & 0xFF) << 16) | (((sf.w >> 23) & 0xFF) << 24);
 
   // partially swizzle the scaling factors
   constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF;  // no divergent branches

...
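The corrected expression isolates the eight exponent bits (bits 30:23) of each FP32 scaling factor and packs them into consecutive bytes, one E8M0 value per byte; the old expression let exponent and mantissa bits of the four factors overlap. A standalone host-side sketch of the same packing (illustrative, not part of the diff):

#include <cstdint>
#include <cstdio>
#include <cstring>

static uint32_t fp32_bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return u;
}

int main() {
  // Exponent fields: 1.0f -> 127 (0x7F), 2.0f -> 128 (0x80),
  //                  0.5f -> 126 (0x7E), 4.0f -> 129 (0x81)
  const float sf[4] = {1.0f, 2.0f, 0.5f, 4.0f};
  uint32_t packed = 0;
  for (int i = 0; i < 4; ++i) {
    packed |= ((fp32_bits(sf[i]) >> 23) & 0xFFu) << (8 * i);
  }
  std::printf("0x%08X\n", packed);  // prints 0x817E807F
  return 0;
}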
@@ -129,7 +131,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
   // store them cooperatively for 512 1x32 tiles in a 128x128 tile
   constexpr uint32_t out_x_stride = 512;
-  void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride;
+  void* const warp_dst =
+      (reinterpret_cast<uint8_t*>(out) + out_tile_y * out_y_stride + out_tile_x * out_x_stride);
   reinterpret_cast<uint4*>(warp_dst)[lane] = sf;
 }

...
@@ -193,21 +196,24 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
   // calculate this warp's input base pointer
   constexpr uint32_t in_x_stride = sizeof(float);
-  const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride;
+  const void* const warp_src =
+      (reinterpret_cast<const uint8_t*>(in) + in_tile_y * in_y_stride + in_tile_x * in_x_stride);
 
   // load scaling factor for this warp's 128x128 tile
   uint32_t sf = *reinterpret_cast<const uint32_t*>(warp_src);
 
-  // broadcast it to four scaling factors for 1x32 tiles
-  sf = (sf << 1) | (sf >> 7);
-  sf = sf | (sf >> 16);
+  // extract and broadcast the exponent byte to four bytes for E8M0 format
+  uint32_t exp_byte = (sf >> 23) & 0xFF;
+  sf = exp_byte | (exp_byte << 8) | (exp_byte << 16) | (exp_byte << 24);
 
   // broadcast it to sixteen scaling factors for 1x32 tiles
   const uint4 sf4{sf, sf, sf, sf};
 
   // store it cooperatively for 512 1x32 tiles in a 128x128 tile
   constexpr uint32_t out_x_stride = 512;
-  void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride;
+  void* const warp_dst =
+      (reinterpret_cast<uint8_t*>(out) + out_tile_y * out_y_stride + out_tile_x * out_x_stride);
   reinterpret_cast<uint4*>(warp_dst)[lane] = sf4;
 }

...
@@ -260,6 +266,9 @@ void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor*
   NVTE_CHECK(output->scale_inv.dtype == DType::kFloat8E8M0,
              "Output must have E8M0 scaling factors");
+  NVTE_CHECK(output->with_gemm_swizzled_scales,
+             "Expected output tensor with scales in GEMM swizzled format.");
   NVTE_CHECK(input->data.dptr != nullptr, "Input must have rowwise data");
+  NVTE_CHECK(output->data.dptr == input->data.dptr, "Output must share data with input");
   NVTE_CHECK(input->scale_inv.dptr != nullptr, "Input must have rowwise scaling factors");

...
transformer_engine/common/transformer_engine.cpp
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 
 #include <transformer_engine/transformer_engine.h>
 
 #include <algorithm>
 #include <atomic>
 #include <climits>
 #include <cstring>
 #include <iostream>
 #include <mutex>
 #include <optional>
 #include <string>
 #include <utility>
 #include <vector>
 
 #include "common.h"
 #include "common/util/cuda_runtime.h"

...
@@ -81,7 +85,7 @@ std::string to_string(const NVTEScalingMode &mode) {
 }
 
 void CheckNoopTensor(const Tensor &t, const std::string &name) {
-  if (t.data.dptr != nullptr) {
+  if (t.data.has_data()) {
     NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(),
                ".");
     NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name,

...
@@ -92,15 +96,30 @@ void CheckNoopTensor(const Tensor &t, const std::string &name) {
 void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
   NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
   if (is_tensor_scaling(t.scaling_mode)) {
-    // per-tensor scaling
-    if (t.has_data()) {
-      NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
-                 "\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")");
-    }
-    if (t.has_columnwise_data()) {
-      NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
-                 "\" has invalid columnwise_scale_inv shape (expected (1), got ",
-                 t.columnwise_scale_inv.shape, ")");
+    if (is_fp8_dtype(t.dtype())) {
+      // FP8 tensor with tensor scaling
+      if (t.has_data()) {
+        NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
+                   "\" has invalid scale_inv shape (expected 1 entry, got ", t.scale_inv.shape,
+                   ")");
+      }
+      if (t.has_columnwise_data()) {
+        NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
+                   "\" has invalid columnwise_scale_inv shape (expected 1 entry, got ",
+                   t.columnwise_scale_inv.shape, ")");
+      }
+    } else {
+      // High-precision tensor
+      if (t.has_data()) {
+        NVTE_CHECK(t.scale_inv.numel() == 0, "Tensor \"", name,
+                   "\" has invalid scale_inv shape (expected 0 entries, got ", t.scale_inv.shape,
+                   ")");
+      }
+      if (t.has_columnwise_data()) {
+        NVTE_CHECK(t.columnwise_scale_inv.numel() == 0, "Tensor \"", name,
+                   "\" has invalid columnwise_scale_inv shape (expected 0 entries, got ",
+                   t.columnwise_scale_inv.shape, ")");
+      }
     }
   } else {
     if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {

...
@@ -163,7 +182,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
   if (is_fp8_dtype(type) || is_int8_dtype(type)) {
     // FP8 input needs to have scale_inv
     if (t.has_data()) {
-      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
+      NVTE_CHECK(t.scale_inv.has_data(), "FP8 scaling factor input ", name,
                  "_scale_inverse must be allocated");
       NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
                  "FP8 scaling factor input ", name,

...
@@ -172,7 +191,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
                  to_string(t.scale_inv.dtype), ")");
     }
     if (t.has_columnwise_data()) {
-      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
+      NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP8 scaling factor input ", name,
                  "_columnwise_scale_inverse must be allocated");
       NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
                      t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,

...
@@ -185,7 +204,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
     // TODO(ksivaman): Fix this to check for amaxes and other details.
     // For now only needed for swizzle.
     if (t.has_data()) {
-      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
+      NVTE_CHECK(t.scale_inv.has_data(), "FP4 scaling factor input ", name,
                  "_scale_inverse must be allocated");
       NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name,
                  "_scale_inverse has invalid dtype "

...
@@ -193,7 +212,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
                  to_string(t.scale_inv.dtype), ")");
     }
     if (t.has_columnwise_data()) {
-      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
+      NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP4 scaling factor input ", name,
                  "_columnwise_scale_inverse must be allocated");
       NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3,
                  "FP8 scaling factor input ", name,

...
@@ -202,11 +221,10 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
                  to_string(t.columnwise_scale_inv.dtype), ")");
     }
   } else {
-    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name);
-    NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name);
-    NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name);
-    NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
-               "Scale_inv is not supported for non-FP8 input ", name);
+    NVTE_CHECK(!t.scale.has_data(), "Scale is not supported for non-FP8 input ", name);
+    NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv is not supported for non-FP8 input ", name);
+    NVTE_CHECK(!t.columnwise_scale_inv.has_data(),
+               "Scale_inv is not supported for non-FP8 input ", name);
   }
 
   NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!");

...
@@ -217,14 +235,14 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
   const DType type = t.dtype();
   if (is_fp8_dtype(type) || is_int8_dtype(type)) {
     // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
-    if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) {
+    if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.has_data()) {
       NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
                  to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")");
-      NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name,
+      NVTE_CHECK(t.amax.numel() == 1, "Invalid shape of amax in output ", name,
                  " (expected 1 entry, got shape=", t.amax.shape, ")");
     }
     if (t.has_data()) {
-      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
+      NVTE_CHECK(t.scale_inv.has_data(), "FP8 scaling factor output ", name,
                  "_scale_inverse must be allocated");
       NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
                  "FP8 scaling factor output ", name,

...
@@ -233,7 +251,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
                  to_string(t.scale_inv.dtype), ")");
     }
     if (t.has_columnwise_data()) {
-      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
+      NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP8 scaling factor output ", name,
                  "_columnwise_scale_inverse must be allocated");
       NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
                      t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,

...
@@ -245,7 +263,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
   } else if (is_fp4_dtype(type)) {
     // FP4 output needs to have the scale_inv
     if (t.has_data()) {
-      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
+      NVTE_CHECK(t.scale_inv.has_data(), "FP4 scaling factor output ", name,
                  "_scale_inverse must be allocated");
       NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name,
                  "_scale_inverse has invalid dtype "

...
@@ -253,7 +271,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
                  to_string(t.scale_inv.dtype), ")");
     }
     if (t.has_columnwise_data()) {
-      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
+      NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP4 scaling factor output ", name,
                  "_columnwise_scale_inverse must be allocated");
       NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3,
                  "FP4 scaling factor output ", name,

...
@@ -262,12 +280,10 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
                  to_string(t.columnwise_scale_inv.dtype), ")");
     }
   } else {
-    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
     // Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax.
     // NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name);
-    NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
-    NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
-               "Scale_inv is not supported for non-FP8 input ", name);
+    NVTE_CHECK(!t.scale.has_data(), "Scale is not supported for non-FP8 output ", name);
+    NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv is not supported for non-FP8 output ", name);
+    NVTE_CHECK(!t.columnwise_scale_inv.has_data(),
+               "Scale_inv is not supported for non-FP8 input ", name);
   }
 
   if (!allow_empty) {

...
@@ -277,6 +293,128 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
   CheckScaleTensorShape(t, name);
 }
 
+void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string &name) {
+  NVTE_CHECK(t.num_tensors > 0, "Grouped tensor ", name, " has no tensors!");
+
+  // Helper lambda to validate shape arrays
+  // All three arrays are OPTIONAL:
+  //   - first_dims: empty if all tensors have same first dimension
+  //   - last_dims: empty if all tensors have same last dimension
+  //   - tensor_offsets: empty if all tensors have same shape (offsets are predictable)
+  auto check_shape_array = [&](const SimpleTensor &arr, const char *arr_name) {
+    if (arr.has_data()) {
+      NVTE_CHECK(arr.shape.size() == 1, "Grouped tensor ", name, " ", arr_name, " must be 1D");
+      NVTE_CHECK(arr.dtype == DType::kInt64, "Grouped tensor ", name, " ", arr_name,
+                 " must have dtype Int64");
+      NVTE_CHECK(arr.shape[0] == t.num_tensors, "Grouped tensor ", name, " ", arr_name,
+                 " size (", arr.shape[0], ") must equal num_tensors (", t.num_tensors, ")");
+    }
+  };
+
+  // Validate shape arrays (all optional)
+  check_shape_array(t.first_dims, "first_dims");
+  check_shape_array(t.last_dims, "last_dims");
+  check_shape_array(t.tensor_offsets, "tensor_offsets");
+
+  // tensor_offsets is required if any dimension varies
+  // (i.e., required unless all_same_shape())
+  if (!t.all_same_shape()) {
+    NVTE_CHECK(t.tensor_offsets.dptr != nullptr, "Grouped tensor ", name,
+               " must have tensor_offsets when any dimension varies"
+               " (first_dims or last_dims is set)");
+  }
+
+  // Validate logical_shape
+  NVTE_CHECK(t.logical_shape.ndim == 2, "Grouped tensor ", name, " logical_shape must be 2D");
+  NVTE_CHECK(t.logical_shape.data[0] > 0 && t.logical_shape.data[1] > 0,
+             "Grouped tensor ", name, " logical_shape must have positive dimensions");
+
+  // Validate all data fields are 1D (flattened)
+  if (t.has_data()) {
+    NVTE_CHECK(t.data.shape.size() == 1, "Grouped tensor ", name, " data must be 1D");
+  }
+  if (t.has_columnwise_data()) {
+    NVTE_CHECK(t.columnwise_data.shape.size() == 1, "Grouped tensor ", name,
+               " columnwise_data must be 1D");
+  }
+
+  // Validate data size matches logical_shape
+  size_t expected_numel = t.logical_shape.data[0] * t.logical_shape.data[1];
+  if (t.has_data()) {
+    NVTE_CHECK(t.data.numel() == expected_numel, "Grouped tensor ", name, " data size (",
+               t.data.numel(), ") must match logical_shape size (", expected_numel, ")");
+  }
+  if (t.has_columnwise_data()) {
+    NVTE_CHECK(t.columnwise_data.numel() == expected_numel, "Grouped tensor ", name,
+               " columnwise_data size (", t.columnwise_data.numel(),
+               ") must match logical_shape size (", expected_numel, ")");
+  }
+}
+
+// Helper function to check scale_inv for both input and output
+static void CheckGroupedScaleInv(const GroupedTensor &t, const std::string &name, bool is_output) {
+  const char *tensor_type = is_output ? "output" : "input";
+
+  // Helper to check scale_inv for both rowwise and columnwise layouts
+  auto check_scales = [&](DType expected_dtype) {
+    if (t.has_data()) {
+      NVTE_CHECK(t.scale_inv.has_data(), tensor_type, " ", name,
+                 " rowwise scale_inv must be allocated");
+      NVTE_CHECK(t.scale_inv.dtype == expected_dtype, tensor_type, " ", name,
+                 " rowwise scale_inv has invalid dtype (expected ", to_string(expected_dtype),
+                 ", got ", to_string(t.scale_inv.dtype), ")");
+    }
+    if (t.has_columnwise_data()) {
+      NVTE_CHECK(t.columnwise_scale_inv.has_data(), tensor_type, " ", name,
+                 " columnwise scale_inv must be allocated");
+      NVTE_CHECK(t.columnwise_scale_inv.dtype == expected_dtype, tensor_type, " ", name,
+                 " columnwise scale_inv has invalid dtype (expected ", to_string(expected_dtype),
+                 ", got ", to_string(t.columnwise_scale_inv.dtype), ")");
+    }
+  };
+
+  // Determine expected dtype based on data type and scaling mode
+  if (is_fp8_dtype(t.dtype()) && is_tensor_scaling(t.scaling_mode)) {
+    check_scales(DType::kFloat32);
+  } else if (is_mxfp8_scaling(t.scaling_mode)) {
+    check_scales(DType::kFloat8E8M0);
+  } else if (is_nvfp4_scaling(t.scaling_mode)) {
+    check_scales(DType::kFloat8E4M3);
+  } else {
+    // Non-quantized types should not have scale/scale_inv
+    NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv not supported for non-quantized ",
+               tensor_type, " ", name);
+    NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv not supported for non-quantized ",
+               tensor_type, " ", name);
+  }
+}
+
+void CheckInputGroupedTensor(const GroupedTensor &t, const std::string &name) {
+  NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input grouped tensor ", name,
+             " not allocated");
+  CheckGroupedScaleInv(t, name, false);
+  CheckGroupedTensorShapeArrays(t, name);
+}
+
+void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string &name, bool allow_empty) {
+  if (!allow_empty) {
+    NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output grouped tensor ", name,
+               " not allocated");
+  }
+
+  // Only perform dtype-specific validation if data is allocated
+  if (t.has_data() || t.has_columnwise_data()) {
+    // Amax validation for delayed scaling
+    if (is_fp8_dtype(t.dtype()) && t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
+      NVTE_CHECK(t.amax.has_data(), "Output ", name, " amax must be allocated");
+      NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Output ", name, " amax must be Float32");
+    }
+    CheckGroupedScaleInv(t, name, true);
+  }
+
+  CheckGroupedTensorShapeArrays(t, name);
+}
+
 class TensorAllocator {
  public:
  static TensorAllocator &instance() {

...
@@ -391,6 +529,89 @@ Tensor *convertNVTETensorCheck(const NVTETensor t) {
   return ptr;
 }
 
+// GroupedTensor allocator - similar pattern to TensorAllocator
+class GroupedTensorAllocator {
+ public:
+  static GroupedTensorAllocator &instance() {
+    static GroupedTensorAllocator allocator;
+    return allocator;
+  }
+
+  ~GroupedTensorAllocator() {}
+
+  NVTEGroupedTensor Allocate(NVTEScalingMode mode, size_t num_tensors, NVTEShape logical_shape) {
+    std::lock_guard<std::mutex> lock(mutex);
+    if (!free_list.empty()) {
+      uintptr_t index = free_list.back();
+      NVTEGroupedTensor ret = reinterpret_cast<NVTEGroupedTensor>(index);
+      free_list.pop_back();
+      // 1-based indexing - fully reinitialize the tensor to avoid stale data
+      memory[index - 1].scaling_mode = mode;
+      memory[index - 1].num_tensors = num_tensors;
+      memory[index - 1].logical_shape = logical_shape;
+      memory[index - 1].nvte_tensor = ret;
+      return ret;
+    }
+    if (memory.size() < memory.capacity()) {
+      memory.emplace_back(mode, num_tensors);
+      GroupedTensor &t = memory.back();
+      size = memory.size();
+      // 1-based indexing
+      uintptr_t index = memory.size();
+      t.logical_shape = logical_shape;
+      t.nvte_tensor = reinterpret_cast<NVTEGroupedTensor>(index);
+      return reinterpret_cast<NVTEGroupedTensor>(index);
+    }
+    NVTE_ERROR("Cannot allocate a new NVTEGroupedTensor. Maximum number of grouped tensors"
+               " reached: ", MAX_GROUPED_TENSOR_NUM,
+               ". There is probably a memory leak in your application.");
+  }
+
+  void Free(NVTEGroupedTensor t) {
+    std::lock_guard<std::mutex> lock(mutex);
+    uintptr_t index = reinterpret_cast<uintptr_t>(t);
+    if (index == 0) return;
+    NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
+    free_list.push_back(index);
+    // Clean up
+    memory[index - 1].clear();
+  }
+
+  GroupedTensor *convertNVTEGroupedTensor(NVTEGroupedTensor t) {
+    uintptr_t index = reinterpret_cast<uintptr_t>(t);
+    // 1-based indexing to enable 0-initialization of NVTEGroupedTensor
+    // to be invalid tensor
+    static_assert(nullptr == 0);
+    if (index != 0 && index <= size) {
+      return &(memory[index - 1]);
+    }
+    return nullptr;
+  }
+
+ private:
+  GroupedTensorAllocator() {
+    std::lock_guard<std::mutex> lock(mutex);
+    memory.reserve(MAX_GROUPED_TENSOR_NUM);
+  }
+
+  std::mutex mutex;
+  std::atomic<size_t> size;
+  // Allocate at most 20 MB for grouped tensors
+  const size_t MAX_GROUPED_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(GroupedTensor);
+  std::vector<uintptr_t> free_list;
+  std::vector<GroupedTensor> memory;
+};
+
+GroupedTensor *convertNVTEGroupedTensor(const NVTEGroupedTensor t) {
+  return GroupedTensorAllocator::instance().convertNVTEGroupedTensor(t);
+}
+
+GroupedTensor *convertNVTEGroupedTensorCheck(const NVTEGroupedTensor t) {
+  GroupedTensor *ptr = GroupedTensorAllocator::instance().convertNVTEGroupedTensor(t);
+  NVTE_CHECK(ptr != nullptr, "Invalid grouped tensor.");
+  return ptr;
+}
 }  // namespace transformer_engine
 
 NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) {

...
@@ -427,7 +648,11 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
  NVTE_CHECK(ndim <= sizeof(ret.data) / sizeof(ret.data[0]),
             "Too many dims for NVTEShape (requested: ", ndim,
             ", max: ", sizeof(ret.data) / sizeof(ret.data[0]), ")");
-  std::copy(data, data + ndim, ret.data);
+  if (data == nullptr) {
+    std::fill(ret.data, ret.data + ndim, 0);
+  } else {
+    std::copy(data, data + ndim, ret.data);
+  }
  ret.ndim = ndim;
  return ret;
}
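With the added branch, passing a null data pointer no longer feeds std::copy a null range (previously undefined behavior for ndim > 0); the shape is zero-filled instead. A short usage sketch, assuming only the behavior shown in this hunk:

  size_t dims[2] = {128, 256};
  NVTEShape s1 = nvte_make_shape(dims, 2);     // copies {128, 256}
  NVTEShape s2 = nvte_make_shape(nullptr, 2);  // now well-defined: zero-fills {0, 0}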
...
...
@@ -540,7 +765,7 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) {
-    return nvte_make_shape(nullptr, 0);
+    return nvte_make_shape(nullptr, 1);
  }
  return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size());
}
...
...
@@ -573,13 +798,14 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
      t->columnwise_amax = *param;
      break;
    default:
-      NVTE_ERROR("Unknown tensor parameter!");
+      NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param_name),
+                 "). Consider using nvte_set_tensor_param_v2 instead.");
  }
}

NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
  if (tensor == nullptr) {
-    return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
+    return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)};
  }
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  switch (param_name) {
...
...
@@ -598,7 +824,148 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
    case kNVTEColumnwiseAmax:
      return t.columnwise_amax;
    default:
-      NVTE_ERROR("Unknown tensor parameter!");
+      NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param_name),
+                 "). Consider using nvte_set_tensor_param_v2 instead.");
  }
}
void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const void *buf,
                              size_t size_in_bytes) {
  // Check attribute and buffer
  NVTE_CHECK(param < kNVTENumTensorParams, "Invalid NVTETensorParam (got ",
             static_cast<int>(param), ")");
  NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL.");
  auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  const auto &attr_size = transformer_engine::Tensor::attr_sizes[param];
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for tensor parameter "
             "(parameter ", static_cast<int>(param), " needs ", attr_size,
             " bytes, but buffer has ", size_in_bytes, " bytes)");
  NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
  // Read from buffer
  switch (param) {
    case kNVTERowwiseData: {
      const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
      t.data = *basic_tensor;
      break;
    }
    case kNVTEColumnwiseData: {
      const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
      t.columnwise_data = *basic_tensor;
      break;
    }
    case kNVTEScale: {
      const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
      t.scale = *basic_tensor;
      break;
    }
    case kNVTEAmax: {
      const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
      t.amax = *basic_tensor;
      break;
    }
    case kNVTERowwiseScaleInv: {
      const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
      t.scale_inv = *basic_tensor;
      break;
    }
    case kNVTEColumnwiseScaleInv: {
      const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
      t.columnwise_scale_inv = *basic_tensor;
      break;
    }
    case kNVTEColumnwiseAmax: {
      const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
      t.columnwise_amax = *basic_tensor;
      break;
    }
    case kNVTEWithGEMMSwizzledScales:
      t.with_gemm_swizzled_scales = static_cast<bool>(*reinterpret_cast<const uint8_t *>(buf));
      break;
    default:
      NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param), ")");
  }
}
void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, void *buf,
                              size_t size_in_bytes, size_t *size_written) {
  using namespace transformer_engine;
  // Check param
  NVTE_CHECK(param < kNVTENumTensorParams, "Invalid NVTETensorParam (got ",
             static_cast<int>(param), ")");
  // Write attribute size if provided
  const auto &attr_size = Tensor::attr_sizes[param];
  if (size_written != nullptr) {
    *size_written = attr_size;
  }
  // Return immediately if buffer is not provided
  if (buf == nullptr) {
    return;
  }
  // Check buffer size
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for tensor parameter "
             "(parameter ", static_cast<int>(param), " needs ", attr_size,
             " bytes, but buffer has ", size_in_bytes, " bytes)");
  // Get C++ tensor
  const Tensor *t = convertNVTETensor(tensor);
  std::optional<Tensor> dummy;
  if (t == nullptr) {
    // Make dummy tensor if provided tensor is invalid
    dummy.emplace();
    t = &(*dummy);
  }
  // Write to buffer
  switch (param) {
    case kNVTERowwiseData: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->data);
      break;
    }
    case kNVTEColumnwiseData: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->columnwise_data);
      break;
    }
    case kNVTEScale: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->scale);
      break;
    }
    case kNVTEAmax: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->amax);
      break;
    }
    case kNVTERowwiseScaleInv: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->scale_inv);
      break;
    }
    case kNVTEColumnwiseScaleInv: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->columnwise_scale_inv);
      break;
    }
    case kNVTEColumnwiseAmax: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->columnwise_amax);
      break;
    }
    case kNVTEWithGEMMSwizzledScales:
      *reinterpret_cast<uint8_t *>(buf) = static_cast<uint8_t>(t->with_gemm_swizzled_scales);
      break;
    default:
      NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param), ")");
  }
}
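A brief usage sketch of the v2 accessors (illustrative only; it assumes nvte_destroy_tensor from the existing public API). Passing buf == nullptr is the size-query idiom, mirroring the quantization-config attribute functions below:

  NVTETensor tensor = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING);
  size_t needed = 0;
  nvte_get_tensor_param_v2(tensor, kNVTERowwiseData, nullptr, 0, &needed);  // query size only
  NVTEBasicTensor rowwise;  // needed == sizeof(NVTEBasicTensor) for this parameter
  nvte_get_tensor_param_v2(tensor, kNVTERowwiseData, &rowwise, sizeof(rowwise), nullptr);
  nvte_set_tensor_param_v2(tensor, kNVTERowwiseData, &rowwise, sizeof(rowwise));  // round-trip
  nvte_destroy_tensor(tensor);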
...
...
@@ -624,14 +991,21 @@ void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
  if (tensor == nullptr) return;
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  // Zero out tensor data if allocated
  if (t.data.dptr != nullptr) {
-    const size_t size_in_bytes = nvte_tensor_size_bytes(tensor);
-    NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream));
+    const auto size = t.data.buffer_size_bytes();
+    if (size > 0) {
+      NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size, stream));
+    }
  }
-  // Set amax to 0 if allocated
+  // Zero out amax if allocated
  if (t.amax.dptr != nullptr) {
-    NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream));
+    const auto size = t.amax.buffer_size_bytes();
+    if (size > 0) {
+      NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, size, stream));
+    }
  }
}
...
...
@@ -642,12 +1016,15 @@ NVTEQuantizationConfig nvte_create_quantization_config() {
void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
                                            NVTEQuantizationConfigAttribute attr, void *buf,
                                            size_t size_in_bytes, size_t *size_written) {
  using namespace transformer_engine;
  // Write attribute size
  NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
             "Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
-  NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
-  const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
-  *size_written = attr_size;
+  const auto &attr_size = QuantizationConfig::attr_sizes[attr];
+  if (size_written != nullptr) {
+    *size_written = attr_size;
+  }
  // Return immediately if buffer is not provided
  if (buf == nullptr) {
...
...
@@ -661,12 +1038,18 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ",
             size_in_bytes, " bytes)");
  // bool size is implementation-dependent, so we explicitly specify
  // uint8_t in the user-facing API.
  auto bool_to_uint8 = [](bool in, void *out) {
    *reinterpret_cast<uint8_t *>(out) = static_cast<uint8_t>(in);
  };
  // Write to buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
-  const auto &config_ = *reinterpret_cast<const transformer_engine::QuantizationConfig *>(config);
+  const auto &config_ = *reinterpret_cast<const QuantizationConfig *>(config);
  switch (attr) {
    case kNVTEQuantizationConfigForcePow2Scales:
-      std::memcpy(buf, &config_.force_pow_2_scales, attr_size);
+      bool_to_uint8(config_.force_pow_2_scales, buf);
      break;
    case kNVTEQuantizationConfigAmaxEpsilon:
      std::memcpy(buf, &config_.amax_epsilon, attr_size);
...
...
@@ -674,8 +1057,23 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
    case kNVTEQuantizationConfigNoopTensor:
      std::memcpy(buf, &config_.noop_tensor, attr_size);
      break;
-    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
-      std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size);
+    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: {
+      // Deprecated
+      const auto invalid = Float8BlockScaleTensorFormat::INVALID;
+      std::memcpy(buf, &invalid, attr_size);
+      break;
+    }
    case kNVTEQuantizationConfigRNGState:
      std::memcpy(buf, &config_.rng_state, attr_size);
      break;
    case kNVTEQuantizationConfigNVFP42DQuantization:
      bool_to_uint8(config_.nvfp4_2d_quantization, buf);
      break;
    case kNVTEQuantizationConfigStochasticRounding:
      bool_to_uint8(config_.stochastic_rounding, buf);
      break;
    case kNVTEQuantizationConfigUseFastMath:
      bool_to_uint8(config_.use_fast_math, buf);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
...
...
@@ -685,10 +1083,12 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
                                            NVTEQuantizationConfigAttribute attr,
                                            const void *buf, size_t size_in_bytes) {
  using namespace transformer_engine;
  // Check attribute and buffer
  NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
             "Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
-  const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
+  const auto &attr_size = QuantizationConfig::attr_sizes[attr];
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for quantization config attribute "
             "(attribute ",
...
...
@@ -696,12 +1096,18 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
" bytes)"
);
NVTE_CHECK
(
buf
!=
nullptr
,
"Invalid buffer (got NULL)"
);
// bool size is implementation-dependent, so we explicitly specify
// uint8_t in the user-facing API.
auto
uint8_to_bool
=
[](
const
void
*
in
,
bool
&
out
)
{
out
=
static_cast
<
bool
>
(
*
reinterpret_cast
<
const
uint8_t
*>
(
in
));
};
// Read from buffer
NVTE_CHECK
(
config
!=
nullptr
,
"Invalid NVTEQuantizationConfig (got NULL)"
);
auto
&
config_
=
*
reinterpret_cast
<
transformer_engine
::
QuantizationConfig
*>
(
config
);
auto
&
config_
=
*
reinterpret_cast
<
QuantizationConfig
*>
(
config
);
switch
(
attr
)
{
case
kNVTEQuantizationConfigForcePow2Scales
:
std
::
memcpy
(
&
config_
.
force_pow_2_scales
,
buf
,
attr_size
);
uint8_to_bool
(
buf
,
config_
.
force_pow_2_scales
);
break
;
case
kNVTEQuantizationConfigAmaxEpsilon
:
std
::
memcpy
(
&
config_
.
amax_epsilon
,
buf
,
attr_size
);
...
...
@@ -710,16 +1116,19 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
      std::memcpy(&config_.noop_tensor, buf, attr_size);
      break;
    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
-      std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
+      // Deprecated
      break;
    case kNVTEQuantizationConfigRNGState:
      std::memcpy(&config_.rng_state, buf, attr_size);
      break;
    case kNVTEQuantizationConfigNVFP42DQuantization:
-      std::memcpy(&config_.nvfp4_2d_quantization, buf, attr_size);
+      uint8_to_bool(buf, config_.nvfp4_2d_quantization);
      break;
    case kNVTEQuantizationConfigStochasticRounding:
-      std::memcpy(&config_.stochastic_rounding, buf, attr_size);
+      uint8_to_bool(buf, config_.stochastic_rounding);
      break;
    case kNVTEQuantizationConfigUseFastMath:
      uint8_to_bool(buf, config_.use_fast_math);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
...
...
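Because bool attributes now cross the C boundary as uint8_t, callers read and write them as one-byte values. A sketch using the stochastic-rounding attribute from the switch above (illustrative, not part of the diff):

  NVTEQuantizationConfig cfg = nvte_create_quantization_config();
  uint8_t flag = 1;  // bools travel as uint8_t across the C API
  nvte_set_quantization_config_attribute(cfg, kNVTEQuantizationConfigStochasticRounding,
                                         &flag, sizeof(flag));
  uint8_t readback = 0;
  size_t written = 0;
  nvte_get_quantization_config_attribute(cfg, kNVTEQuantizationConfigStochasticRounding,
                                         &readback, sizeof(readback), &written);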
@@ -736,12 +1145,146 @@ int nvte_is_non_tn_fp8_gemm_supported() {
#if USE_ROCM
  return true;
#else
-  int deviceComputeCapability =
-      transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
-  // Note: this is temporary restriction and should be lifted in the future.
-  // (remove the note once it's done.)
-  return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
-         deviceComputeCapability >= 130;
+  int num_devices = transformer_engine::cuda::num_devices();
+  static std::vector<int> cache(num_devices, -1);
+  static std::vector<std::once_flag> flags(num_devices);
+  int device_id = transformer_engine::cuda::current_device();
+  std::call_once(flags[device_id], [&]() {
+    int deviceComputeCapability = transformer_engine::cuda::sm_arch(device_id);
+    // Note: this is temporary restriction and should be lifted in the future.
+    // (remove the note once it's done.)
+    cache[device_id] = (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
+                       deviceComputeCapability >= 130;
+  });
+  return cache[device_id];
#endif
}
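The rewrite memoizes the answer per device so repeated calls skip the sm_arch query. The pattern in isolation, as a minimal self-contained sketch (not TE code; query_property is a hypothetical stand-in for the expensive per-device query):

  #include <mutex>
  #include <vector>

  static int query_property(int device_id) { return device_id; }  // hypothetical stand-in

  int cached_property(int device_id, int num_devices) {
    static std::vector<int> cache(num_devices, -1);         // sized once, on first call
    static std::vector<std::once_flag> flags(num_devices);  // one flag per device
    std::call_once(flags[device_id],
                   [&] { cache[device_id] = query_property(device_id); });
    return cache[device_id];
  }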
// Grouped Tensor C API implementations
NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_t num_tensors,
                                             NVTEShape logical_shape) {
  NVTE_CHECK(num_tensors > 0, "Number of tensors must be greater than 0");
  NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
  NVTE_CHECK(logical_shape.data[0] > 0 && logical_shape.data[1] > 0,
             "Logical shape must have positive dimensions");
  NVTEGroupedTensor ret = transformer_engine::GroupedTensorAllocator::instance().Allocate(
      scaling_mode, num_tensors, logical_shape);
  return ret;
}

void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor) {
  transformer_engine::GroupedTensorAllocator::instance().Free(tensor);
}

void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name,
                                   const NVTEBasicTensor *param) {
  NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL.");
  auto *t = transformer_engine::convertNVTEGroupedTensor(*tensor);
  NVTE_CHECK(t != nullptr, "Grouped tensor is not allocated.");
  NVTE_CHECK(param != nullptr, "Grouped tensor param can't be NULL.");
  switch (param_name) {
    case kNVTEGroupedRowwiseData:
      t->data = *param;
      break;
    case kNVTEGroupedColumnwiseData:
      t->columnwise_data = *param;
      break;
    case kNVTEGroupedScale:
      t->scale = *param;
      break;
    case kNVTEGroupedAmax:
      t->amax = *param;
      break;
    case kNVTEGroupedRowwiseScaleInv:
      t->scale_inv = *param;
      break;
    case kNVTEGroupedColumnwiseScaleInv:
      t->columnwise_scale_inv = *param;
      break;
    case kNVTEGroupedColumnwiseAmax:
      t->columnwise_amax = *param;
      break;
    case kNVTEGroupedFirstDims:
      t->first_dims = *param;
      // Validate it's Int64
      NVTE_CHECK(t->first_dims.dtype == transformer_engine::DType::kInt64,
                 "first_dims must have dtype Int64");
      break;
    case kNVTEGroupedLastDims:
      t->last_dims = *param;
      // Validate it's Int64
      NVTE_CHECK(t->last_dims.dtype == transformer_engine::DType::kInt64,
                 "last_dims must have dtype Int64");
      break;
    case kNVTEGroupedTensorOffsets:
      t->tensor_offsets = *param;
      // Validate it's Int64
      NVTE_CHECK(t->tensor_offsets.dtype == transformer_engine::DType::kInt64,
                 "tensor_offsets must have dtype Int64");
      break;
    default:
      NVTE_ERROR("Unknown grouped tensor parameter!");
  }
}

NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor,
                                              NVTEGroupedTensorParam param_name) {
  if (tensor == nullptr) {
    return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)};
  }
  const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
  switch (param_name) {
    case kNVTEGroupedRowwiseData:
      return t.data;
    case kNVTEGroupedColumnwiseData:
      return t.columnwise_data;
    case kNVTEGroupedScale:
      return t.scale;
    case kNVTEGroupedAmax:
      return t.amax;
    case kNVTEGroupedRowwiseScaleInv:
      return t.scale_inv;
    case kNVTEGroupedColumnwiseScaleInv:
      return t.columnwise_scale_inv;
    case kNVTEGroupedColumnwiseAmax:
      return t.columnwise_amax;
    case kNVTEGroupedFirstDims:
      return t.first_dims;
    case kNVTEGroupedLastDims:
      return t.last_dims;
    case kNVTEGroupedTensorOffsets:
      return t.tensor_offsets;
    default:
      NVTE_ERROR("Unknown grouped tensor parameter!");
  }
}

size_t nvte_grouped_tensor_num_tensors(const NVTEGroupedTensor tensor) {
  auto *t = transformer_engine::convertNVTEGroupedTensor(tensor);
  if (t == nullptr) return 0;
  return t->num_tensors;
}

NVTEDType nvte_grouped_tensor_type(const NVTEGroupedTensor tensor) {
  auto *t = transformer_engine::convertNVTEGroupedTensor(tensor);
  if (t == nullptr) return kNVTEFloat32;
  return static_cast<NVTEDType>(t->dtype());
}

NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor) {
  if (tensor == nullptr) {
    return NVTE_DELAYED_TENSOR_SCALING;
  }
  const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
  return t.scaling_mode;
}

NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) {
  if (tensor == nullptr) {
    return nvte_make_shape(nullptr, 1);
  }
  const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
  return t.logical_shape;
}
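Putting the new C API together, a caller might build and tear down a grouped tensor as below. This is an illustrative sketch: shapes are placeholders, first_dims_buf is a hypothetical device buffer of int64_t row counts, and kNVTEInt64 is assumed to be the NVTEDType counterpart of DType::kInt64:

  size_t dims[2] = {1024, 1024};
  NVTEShape shape = nvte_make_shape(dims, 2);
  NVTEGroupedTensor gt = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, 4, shape);

  size_t num = 4;
  void *first_dims_buf = nullptr;  // hypothetical: device buffer of 4 int64_t values
  NVTEBasicTensor first_dims = {first_dims_buf, kNVTEInt64, nvte_make_shape(&num, 1)};
  nvte_set_grouped_tensor_param(&gt, kNVTEGroupedFirstDims, &first_dims);  // validates Int64

  size_t n = nvte_grouped_tensor_num_tensors(gt);  // == 4
  nvte_destroy_grouped_tensor(gt);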
transformer_engine/common/transpose/cast_transpose.cu
View file @ 0d874a4e
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/transpose/cast_transpose.h
View file @ 0d874a4e
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -36,7 +36,7 @@ enum class FP8BlockwiseRowwiseOption {
  NONE,
  // Rowwise data, scales in GEMM format
  ROWWISE_GEMM_READY,
-  // Rowwise data, scales in compact format, needs extra processing (padding, transposing) before GEMM
+  // Deprecated
  ROWWISE_COMPACT
};
...
...
@@ -50,8 +50,7 @@ enum class FP8BlockwiseColumnwiseOption {
  // On Hopper sm90, GEMM_READY means that columnwise quantization also fuses transpose op
  // On higher sm versions with TN,NT,NN fp8 gemm, GEMM_READY doesn't fuse transpose
  COLUMNWISE_GEMM_READY,
-  // Columnwise data in original shape
-  // Scales in compact format, needs extra processing (padding, transposing) before GEMM
+  // Deprecated
  COLUMNWISE_COMPACT
};
...
...
transformer_engine/common/transpose/cast_transpose_fusion.cu
View file @ 0d874a4e
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -203,8 +203,6 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /
      workspace->data.dtype);
  const size_t required_size =
      get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
-  NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
-             num_rows_partial_dbias, ",", row_length, "), found ())");
  NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
             num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
             "; found dims=", workspace->data.shape,
...
...
transformer_engine/common/transpose/multi_cast_transpose.cu
View file @ 0d874a4e
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
View file @ 0d874a4e
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -908,7 +908,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
  }
  NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
-  const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
+  const size_t row_length = input.shape.size() > 0 ? input.shape.back() : 1;
  size_t num_rows = 1;
  for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
    num_rows *= input.shape.at(i);
...
...
@@ -927,12 +927,14 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
  const float *noop_ptr = reinterpret_cast<const float *>(noop_tensor.dptr);

  if (return_transpose) {
-    NVTE_CHECK(output_t.shape.size() == input.shape.size(),
-               "output_t must have same number of dimensions as input.");
+    NVTE_CHECK(output_t.shape.size() == input.shape.size(), "input (shape=", input.shape,
+               ") and output_t (shape=", output_t.shape, ") have incompatible dims.");
    if (output_t.shape.size() > 0) {
-      NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
+      NVTE_CHECK(output_t.shape.front() == input.shape.back(), "input (shape=", input.shape,
+                 ") and output_t (shape=", output_t.shape, ") have incompatible dims.");
      for (size_t i = 1; i < output_t.shape.size(); ++i) {
-        NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
+        NVTE_CHECK(output_t.shape[i] == input.shape[i - 1], "input (shape=", input.shape,
+                   ") and output_t (shape=", output_t.shape, ") have incompatible dims.");
      }
    }
    NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type.");
...
...
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
View file @ 0d874a4e
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
View file @ 0d874a4e
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/transpose/rtc/cast_transpose.cu
View file @ 0d874a4e
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu
View file @ 0d874a4e
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/transpose/rtc/swap_first_dims.cu
View file @ 0d874a4e
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...