Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0d874a4e
Commit
0d874a4e
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main' of v2.12
parents
a68e5f87
dfdd3820
Changes
646
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
622 additions
and
417 deletions
+622
-417
transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp
...ne/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp
+11
-1
transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp
...er_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
...mer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp
...ormer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/normalization.cpp
transformer_engine/pytorch/csrc/extensions/normalization.cpp
+39
-15
transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp
transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/padding.cpp
transformer_engine/pytorch/csrc/extensions/padding.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/permutation.cpp
transformer_engine/pytorch/csrc/extensions/permutation.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/pybind.cpp
transformer_engine/pytorch/csrc/extensions/pybind.cpp
+17
-2
transformer_engine/pytorch/csrc/extensions/recipe.cpp
transformer_engine/pytorch/csrc/extensions/recipe.cpp
+3
-3
transformer_engine/pytorch/csrc/extensions/router.cpp
transformer_engine/pytorch/csrc/extensions/router.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/softmax.cpp
transformer_engine/pytorch/csrc/extensions/softmax.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/swizzle.cpp
transformer_engine/pytorch/csrc/extensions/swizzle.cpp
+394
-0
transformer_engine/pytorch/csrc/extensions/transpose.cpp
transformer_engine/pytorch/csrc/extensions/transpose.cpp
+1
-1
transformer_engine/pytorch/csrc/pybind.h
transformer_engine/pytorch/csrc/pybind.h
+1
-1
transformer_engine/pytorch/csrc/quantizer.cpp
transformer_engine/pytorch/csrc/quantizer.cpp
+101
-108
transformer_engine/pytorch/csrc/type_converters.cpp
transformer_engine/pytorch/csrc/type_converters.cpp
+19
-5
transformer_engine/pytorch/csrc/util.cpp
transformer_engine/pytorch/csrc/util.cpp
+0
-256
transformer_engine/pytorch/csrc/util.h
transformer_engine/pytorch/csrc/util.h
+27
-16
transformer_engine/pytorch/custom_recipes/__init__.py
transformer_engine/pytorch/custom_recipes/__init__.py
+1
-1
No files found.
transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
@@ -20,4 +20,14 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
...
@@ -20,4 +20,14 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
force_pow_2_scales
,
epsilon
,
at
::
cuda
::
getCurrentCUDAStream
());
force_pow_2_scales
,
epsilon
,
at
::
cuda
::
getCurrentCUDAStream
());
}
}
void
multi_tensor_compute_scale_inv_e8m0_cuda
(
int
chunk_size
,
const
py
::
object
&
dummy
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
)
{
NVTE_CHECK
(
dummy
.
is_none
(),
"No-op flag is not supported."
);
auto
[
_
,
__
,
tensor_lists_ptr
,
num_lists
,
num_tensors
]
=
makeTransformerEngineTensorList
(
tensor_lists
);
nvte_multi_tensor_compute_scale_inv_e8m0_cuda
(
chunk_size
,
tensor_lists_ptr
.
data
(),
num_lists
,
num_tensors
,
at
::
cuda
::
getCurrentCUDAStream
());
}
}
// namespace transformer_engine::pytorch
}
// namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
...
transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
...
transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
...
transformer_engine/pytorch/csrc/extensions/normalization.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
@@ -64,6 +64,11 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
...
@@ -64,6 +64,11 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
const
bool
zero_centered_gamma
)
{
const
bool
zero_centered_gamma
)
{
using
namespace
transformer_engine
::
pytorch
::
detail
;
using
namespace
transformer_engine
::
pytorch
::
detail
;
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at
::
cuda
::
CUDAGuard
device_guard
(
input
.
cast
<
at
::
Tensor
>
().
device
());
// Input and param tensors
// Input and param tensors
auto
none
=
py
::
none
();
auto
none
=
py
::
none
();
const
TensorWrapper
&
input_nvte
=
makeTransformerEngineTensor
(
input
,
none
);
const
TensorWrapper
&
input_nvte
=
makeTransformerEngineTensor
(
input
,
none
);
...
@@ -84,14 +89,8 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
...
@@ -84,14 +89,8 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
TensorWrapper
mu_nvte
=
makeTransformerEngineTensor
(
mu_py
);
TensorWrapper
mu_nvte
=
makeTransformerEngineTensor
(
mu_py
);
TensorWrapper
rsigma_nvte
=
makeTransformerEngineTensor
(
rsigma_py
);
TensorWrapper
rsigma_nvte
=
makeTransformerEngineTensor
(
rsigma_py
);
//
Output tenso
r
//
Quantize
r
auto
quantizer_cpp
=
convert_quantizer
(
quantizer
);
auto
quantizer_cpp
=
convert_quantizer
(
quantizer
);
TensorWrapper
out_nvte
;
if
(
out
.
is_none
())
{
std
::
tie
(
out_nvte
,
out
)
=
quantizer_cpp
->
create_tensor
(
shape
,
out_dtype
);
}
else
{
out_nvte
=
makeTransformerEngineTensor
(
out
,
quantizer
);
}
// Choose implementation
// Choose implementation
enum
class
Impl
{
enum
class
Impl
{
...
@@ -130,6 +129,19 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
...
@@ -130,6 +129,19 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
}
}
}
}
// Output tensor
TensorWrapper
out_nvte
;
if
(
out
.
is_none
())
{
if
(
impl
==
Impl
::
FULLY_FUSED
)
{
// FP8 has no special logic to optimize for GEMM, MXFP8 cuDNN
// kernel does not support GEMM swizzled scales
quantizer_cpp
->
optimize_for_gemm
=
false
;
}
std
::
tie
(
out_nvte
,
out
)
=
quantizer_cpp
->
create_tensor
(
shape
,
out_dtype
);
}
else
{
out_nvte
=
makeTransformerEngineTensor
(
out
,
quantizer
);
}
// Construct unquantized output tensor if needed
// Construct unquantized output tensor if needed
TensorWrapper
unquantized_out_nvte
;
TensorWrapper
unquantized_out_nvte
;
py
::
object
unquantized_out
;
py
::
object
unquantized_out
;
...
@@ -294,6 +306,11 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
...
@@ -294,6 +306,11 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
const
int
sm_margin
,
const
bool
zero_centered_gamma
)
{
const
int
sm_margin
,
const
bool
zero_centered_gamma
)
{
using
namespace
transformer_engine
::
pytorch
::
detail
;
using
namespace
transformer_engine
::
pytorch
::
detail
;
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at
::
cuda
::
CUDAGuard
device_guard
(
input
.
cast
<
at
::
Tensor
>
().
device
());
// Input and param tensors
// Input and param tensors
auto
none
=
py
::
none
();
auto
none
=
py
::
none
();
const
TensorWrapper
&
input_nvte
=
makeTransformerEngineTensor
(
input
,
none
);
const
TensorWrapper
&
input_nvte
=
makeTransformerEngineTensor
(
input
,
none
);
...
@@ -308,14 +325,8 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
...
@@ -308,14 +325,8 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
at
::
Tensor
rsigma_py
=
at
::
empty
({
static_cast
<
int64_t
>
(
outer_size
)},
at
::
CUDA
(
at
::
kFloat
));
at
::
Tensor
rsigma_py
=
at
::
empty
({
static_cast
<
int64_t
>
(
outer_size
)},
at
::
CUDA
(
at
::
kFloat
));
TensorWrapper
rsigma_nvte
=
makeTransformerEngineTensor
(
rsigma_py
);
TensorWrapper
rsigma_nvte
=
makeTransformerEngineTensor
(
rsigma_py
);
//
Output tenso
r
//
Quantize
r
auto
quantizer_cpp
=
convert_quantizer
(
quantizer
);
auto
quantizer_cpp
=
convert_quantizer
(
quantizer
);
TensorWrapper
out_nvte
;
if
(
out
.
is_none
())
{
std
::
tie
(
out_nvte
,
out
)
=
quantizer_cpp
->
create_tensor
(
shape
,
out_dtype
);
}
else
{
out_nvte
=
makeTransformerEngineTensor
(
out
,
quantizer
);
}
// Choose implementation
// Choose implementation
enum
class
Impl
{
enum
class
Impl
{
...
@@ -354,6 +365,19 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
...
@@ -354,6 +365,19 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
}
}
}
}
// Output tensor
TensorWrapper
out_nvte
;
if
(
out
.
is_none
())
{
if
(
impl
==
Impl
::
FULLY_FUSED
)
{
// FP8 has no special logic to optimize for GEMM, MXFP8 cuDNN
// kernel does not support GEMM swizzled scales
quantizer_cpp
->
optimize_for_gemm
=
false
;
}
std
::
tie
(
out_nvte
,
out
)
=
quantizer_cpp
->
create_tensor
(
shape
,
out_dtype
);
}
else
{
out_nvte
=
makeTransformerEngineTensor
(
out
,
quantizer
);
}
// Construct unquantized output tensor if needed
// Construct unquantized output tensor if needed
TensorWrapper
unquantized_out_nvte
;
TensorWrapper
unquantized_out_nvte
;
py
::
object
unquantized_out
;
py
::
object
unquantized_out
;
...
...
transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
...
transformer_engine/pytorch/csrc/extensions/padding.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
...
transformer_engine/pytorch/csrc/extensions/permutation.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
...
transformer_engine/pytorch/csrc/extensions/pybind.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
@@ -248,7 +248,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -248,7 +248,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Multi-tensor quantize"
,
py
::
arg
(
"tensor_list"
),
py
::
arg
(
"quantizer_list"
));
"Multi-tensor quantize"
,
py
::
arg
(
"tensor_list"
),
py
::
arg
(
"quantizer_list"
));
m
.
def
(
"split_quantize"
,
&
transformer_engine
::
pytorch
::
split_quantize
,
m
.
def
(
"split_quantize"
,
&
transformer_engine
::
pytorch
::
split_quantize
,
"Split and multi-tensor quantize"
,
py
::
arg
(
"tensor"
),
py
::
arg
(
"split_sections"
),
"Split and multi-tensor quantize"
,
py
::
arg
(
"tensor"
),
py
::
arg
(
"split_sections"
),
py
::
arg
(
"quantizer_list"
));
py
::
arg
(
"quantizer_list"
)
,
py
::
arg
(
"disable_bulk_allocation"
)
=
false
);
m
.
def
(
"te_general_grouped_gemm"
,
&
transformer_engine
::
pytorch
::
te_general_grouped_gemm
,
m
.
def
(
"te_general_grouped_gemm"
,
&
transformer_engine
::
pytorch
::
te_general_grouped_gemm
,
"Grouped GEMM"
);
"Grouped GEMM"
);
#ifdef USE_ROCM
#ifdef USE_ROCM
...
@@ -296,10 +296,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -296,10 +296,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Partial cast from master weights for fp8 block scaling"
,
py
::
arg
(
"inp"
),
py
::
arg
(
"out"
),
"Partial cast from master weights for fp8 block scaling"
,
py
::
arg
(
"inp"
),
py
::
arg
(
"out"
),
py
::
arg
(
"scale"
),
py
::
arg
(
"h"
),
py
::
arg
(
"w"
),
py
::
arg
(
"start_offset"
),
py
::
arg
(
"block_len"
),
py
::
arg
(
"scale"
),
py
::
arg
(
"h"
),
py
::
arg
(
"w"
),
py
::
arg
(
"start_offset"
),
py
::
arg
(
"block_len"
),
py
::
arg
(
"out_dtype"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
arg
(
"out_dtype"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"mxfp8_scaling_compute_partial_amax"
,
&
transformer_engine
::
pytorch
::
mxfp8_scaling_compute_partial_amax
,
"Compute partial amax from master weights for fp8 mxfp8 scaling"
,
py
::
arg
(
"input"
),
py
::
arg
(
"amax_rowwise"
),
py
::
arg
(
"amax_colwise"
),
py
::
arg
(
"rows"
),
py
::
arg
(
"cols"
),
py
::
arg
(
"start_offset"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"mxfp8_scaling_partial_cast"
,
&
transformer_engine
::
pytorch
::
mxfp8_scaling_partial_cast
,
"Partial cast from master weights for fp8 mxfp8 scaling"
,
py
::
arg
(
"input"
),
py
::
arg
(
"output_rowwise"
),
py
::
arg
(
"output_colwise"
),
py
::
arg
(
"scale_inv_rowwise"
),
py
::
arg
(
"scale_inv_colwise"
),
py
::
arg
(
"rows"
),
py
::
arg
(
"cols"
),
py
::
arg
(
"start_offset"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"fused_multi_row_padding"
,
&
transformer_engine
::
pytorch
::
fused_multi_row_padding
,
m
.
def
(
"fused_multi_row_padding"
,
&
transformer_engine
::
pytorch
::
fused_multi_row_padding
,
"Fused Multi-tensor padding"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
"Fused Multi-tensor padding"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"fused_multi_row_unpadding"
,
&
transformer_engine
::
pytorch
::
fused_multi_row_unpadding
,
m
.
def
(
"fused_multi_row_unpadding"
,
&
transformer_engine
::
pytorch
::
fused_multi_row_unpadding
,
"Fused Multi-tensor unpadding"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
"Fused Multi-tensor unpadding"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"swizzle_scales_for_gemm_"
,
&
transformer_engine
::
pytorch
::
inplace_swizzle_scale_for_gemm
,
"Convert tensor block scales into GEMM swizzled format"
);
// attention kernels
// attention kernels
m
.
def
(
"fa_prepare_fwd"
,
&
transformer_engine
::
pytorch
::
fa_prepare_fwd
,
m
.
def
(
"fa_prepare_fwd"
,
&
transformer_engine
::
pytorch
::
fa_prepare_fwd
,
...
@@ -450,6 +462,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -450,6 +462,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"multi_tensor_compute_scale_and_scale_inv"
,
m
.
def
(
"multi_tensor_compute_scale_and_scale_inv"
,
&
transformer_engine
::
pytorch
::
multi_tensor_compute_scale_and_scale_inv_cuda
,
&
transformer_engine
::
pytorch
::
multi_tensor_compute_scale_and_scale_inv_cuda
,
"Fused compute scale and scale_inv from amax"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
"Fused compute scale and scale_inv from amax"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"multi_tensor_compute_scale_inv_e8m0"
,
&
transformer_engine
::
pytorch
::
multi_tensor_compute_scale_inv_e8m0_cuda
,
"Fused compute E8M0 scale_inv from amax"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
// Comm+GEMM Overlap
// Comm+GEMM Overlap
m
.
def
(
"bulk_overlap_ag_with_external_gemm"
,
m
.
def
(
"bulk_overlap_ag_with_external_gemm"
,
...
...
transformer_engine/pytorch/csrc/extensions/recipe.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
@@ -22,8 +22,8 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
...
@@ -22,8 +22,8 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
TORCH_CHECK
(
amax
.
numel
()
==
1
,
"amax must have exactly one element"
);
TORCH_CHECK
(
amax
.
numel
()
==
1
,
"amax must have exactly one element"
);
auto
*
amax_ptr
=
amax
.
data_ptr
<
float
>
();
auto
*
amax_ptr
=
amax
.
data_ptr
<
float
>
();
TensorWrapper
fake_te_output
(
TensorWrapper
fake_te_output
(
nullptr
,
te_input
.
shape
(),
/*dptr=*/
nullptr
,
te_input
.
shape
(),
DType
::
kFloat
8E4M
3
,
// It doesn't matter because we only compute amax.
DType
::
kFloat3
2
,
// It doesn't matter because we only compute amax.
amax_ptr
);
amax_ptr
);
nvte_compute_amax
(
te_input
.
data
(),
fake_te_output
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
nvte_compute_amax
(
te_input
.
data
(),
fake_te_output
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
...
...
transformer_engine/pytorch/csrc/extensions/router.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
...
transformer_engine/pytorch/csrc/extensions/softmax.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
...
transformer_engine/pytorch/csrc/extensions/swizzle.cpp
0 → 100644
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
#include "common.h"
#include "common/common.h"
#include "extensions.h"
#include "pybind.h"
#include "util.h"
namespace
transformer_engine
{
namespace
pytorch
{
namespace
{
void
reset_tensor_data
(
transformer_engine
::
TensorWrapper
&
tensor
,
bool
rowwise
,
bool
columnwise
)
{
NVTEShape
shape
;
shape
.
ndim
=
1
;
shape
.
data
[
0
]
=
0
;
const
transformer_engine
::
DType
dtype
=
transformer_engine
::
DType
::
kFloat32
;
if
(
rowwise
)
{
tensor
.
set_rowwise_data
(
nullptr
,
dtype
,
shape
);
tensor
.
set_rowwise_scale_inv
(
nullptr
,
dtype
,
shape
);
}
if
(
columnwise
)
{
tensor
.
set_columnwise_data
(
nullptr
,
dtype
,
shape
);
tensor
.
set_columnwise_scale_inv
(
nullptr
,
dtype
,
shape
);
}
}
}
// namespace
std
::
tuple
<
std
::
optional
<
at
::
Tensor
>
,
std
::
optional
<
at
::
Tensor
>>
swizzle_scales_for_gemm
(
transformer_engine
::
TensorWrapper
&
tensor
,
bool
rowwise_usage
,
bool
columnwise_usage
)
{
// Return early if scale swizzling is not required
const
auto
scaling_mode
=
tensor
.
scaling_mode
();
switch
(
scaling_mode
)
{
case
NVTE_MXFP8_1D_SCALING
:
case
NVTE_NVFP4_1D_SCALING
:
// Tensor format requires scale swizzling
break
;
case
NVTE_INVALID_SCALING
:
NVTE_ERROR
(
"Invalid scaling mode for swizzling scaling factors."
);
default:
// Tensor format does not require scale swizzling for GEMM
return
{
std
::
nullopt
,
std
::
nullopt
};
}
// Return early if scales are already swizzled
if
(
tensor
.
get_with_gemm_swizzled_scales
())
{
return
{
std
::
nullopt
,
std
::
nullopt
};
}
// CUDA stream
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
// Swizzle row-wise scales if needed
std
::
optional
<
at
::
Tensor
>
rowwise_scales_pyt
;
if
(
rowwise_usage
)
{
// Buffer for unswizzled scales
const
auto
input_scales_nvte
=
tensor
.
get_rowwise_scale_inv
();
void
*
input_scales_dptr
=
input_scales_nvte
.
data_ptr
;
const
NVTEShape
input_scales_shape
=
input_scales_nvte
.
shape
;
const
auto
scales_dtype
=
static_cast
<
DType
>
(
input_scales_nvte
.
dtype
);
// Allocate buffer for swizzled scales
const
NVTEShape
output_scales_shape
=
input_scales_shape
;
rowwise_scales_pyt
=
allocateSpace
(
input_scales_shape
,
scales_dtype
,
false
);
void
*
output_scales_dptr
=
getDataPtr
(
*
rowwise_scales_pyt
);
// Initialize TE tensors with scales
const
auto
data_nvte
=
tensor
.
get_rowwise_data
();
const
auto
data_dtype
=
static_cast
<
DType
>
(
data_nvte
.
dtype
);
TensorWrapper
input_nvte
(
scaling_mode
);
input_nvte
.
set_rowwise_data
(
nullptr
,
data_dtype
,
data_nvte
.
shape
);
input_nvte
.
set_rowwise_scale_inv
(
input_scales_dptr
,
scales_dtype
,
input_scales_shape
);
TensorWrapper
output_nvte
(
scaling_mode
);
output_nvte
.
set_rowwise_data
(
nullptr
,
data_dtype
,
data_nvte
.
shape
);
output_nvte
.
set_rowwise_scale_inv
(
output_scales_dptr
,
scales_dtype
,
output_scales_shape
);
output_nvte
.
set_with_gemm_swizzled_scales
(
true
);
// Launch kernel
NVTE_SCOPED_GIL_RELEASE
(
{
nvte_swizzle_scaling_factors
(
input_nvte
.
data
(),
output_nvte
.
data
(),
stream
);
});
// Update tensor with swizzled scales
tensor
.
set_rowwise_scale_inv
(
output_scales_dptr
,
scales_dtype
,
output_scales_shape
);
}
// Swizzle column-wise scales if needed
std
::
optional
<
at
::
Tensor
>
columnwise_scales_pyt
;
if
(
columnwise_usage
)
{
// Buffer for unswizzled scales
const
auto
input_scales_nvte
=
tensor
.
get_columnwise_scale_inv
();
void
*
input_scales_dptr
=
input_scales_nvte
.
data_ptr
;
const
NVTEShape
input_scales_shape
=
input_scales_nvte
.
shape
;
const
auto
scales_dtype
=
static_cast
<
DType
>
(
input_scales_nvte
.
dtype
);
// Allocate buffer for swizzled scales
const
NVTEShape
output_scales_shape
=
input_scales_shape
;
columnwise_scales_pyt
=
allocateSpace
(
input_scales_shape
,
scales_dtype
,
false
);
void
*
output_scales_dptr
=
getDataPtr
(
*
columnwise_scales_pyt
);
// Initialize TE tensors with scales
const
auto
data_nvte
=
tensor
.
get_columnwise_data
();
const
auto
data_dtype
=
static_cast
<
DType
>
(
data_nvte
.
dtype
);
TensorWrapper
input_nvte
(
scaling_mode
);
input_nvte
.
set_columnwise_data
(
nullptr
,
data_dtype
,
data_nvte
.
shape
);
input_nvte
.
set_columnwise_scale_inv
(
input_scales_dptr
,
scales_dtype
,
input_scales_shape
);
TensorWrapper
output_nvte
(
scaling_mode
);
output_nvte
.
set_columnwise_data
(
nullptr
,
data_dtype
,
data_nvte
.
shape
);
output_nvte
.
set_columnwise_scale_inv
(
output_scales_dptr
,
scales_dtype
,
output_scales_shape
);
output_nvte
.
set_with_gemm_swizzled_scales
(
true
);
// Launch kernel
NVTE_SCOPED_GIL_RELEASE
(
{
nvte_swizzle_scaling_factors
(
input_nvte
.
data
(),
output_nvte
.
data
(),
stream
);
});
// Update tensor with swizzled scales
tensor
.
set_columnwise_scale_inv
(
output_scales_dptr
,
scales_dtype
,
output_scales_shape
);
}
// Update tensor
reset_tensor_data
(
tensor
,
!
rowwise_usage
,
!
columnwise_usage
);
tensor
.
set_with_gemm_swizzled_scales
(
true
);
return
{
std
::
move
(
rowwise_scales_pyt
),
std
::
move
(
columnwise_scales_pyt
)};
}
std
::
optional
<
at
::
Tensor
>
multi_tensor_swizzle_scales_for_gemm
(
std
::
vector
<
transformer_engine
::
TensorWrapper
>
&
tensors
,
bool
rowwise_usage
,
bool
columnwise_usage
)
{
// Checks and trivial cases
NVTE_CHECK
(
rowwise_usage
!=
columnwise_usage
,
"Expect exactly one of rowwise_usage="
,
rowwise_usage
,
" and columnwise_usage="
,
columnwise_usage
,
"."
);
if
(
tensors
.
empty
())
{
return
std
::
nullopt
;
}
const
auto
scaling_mode
=
tensors
.
front
().
scaling_mode
();
for
(
const
auto
&
tensor
:
tensors
)
{
NVTE_CHECK
(
tensor
.
scaling_mode
()
==
scaling_mode
,
"Tensors have different scaling modes"
);
}
// Return early if scale swizzling is not required
switch
(
scaling_mode
)
{
case
NVTE_MXFP8_1D_SCALING
:
case
NVTE_NVFP4_1D_SCALING
:
// Tensor format requires scale swizzling
break
;
case
NVTE_INVALID_SCALING
:
NVTE_ERROR
(
"Invalid scaling mode for swizzling scaling factors."
);
default:
// Tensor format does not require scale swizzling for GEMM
return
std
::
nullopt
;
}
// Filter out tensors that already have swizzled scales
std
::
vector
<
TensorWrapper
*>
tensors_needing_swizzle
;
for
(
auto
&
tensor
:
tensors
)
{
if
(
!
tensor
.
get_with_gemm_swizzled_scales
())
{
tensors_needing_swizzle
.
push_back
(
&
tensor
);
}
}
if
(
tensors_needing_swizzle
.
empty
())
{
return
std
::
nullopt
;
}
// Determine buffer size needed for swizzled scales
std
::
vector
<
size_t
>
output_scales_offsets
;
size_t
output_scales_bytes
=
0
;
for
(
auto
&
tensor
:
tensors_needing_swizzle
)
{
const
auto
scales_nvte
=
(
rowwise_usage
?
tensor
->
get_rowwise_scale_inv
()
:
tensor
->
get_columnwise_scale_inv
());
const
auto
&
shape
=
scales_nvte
.
shape
;
const
auto
dtype
=
static_cast
<
DType
>
(
scales_nvte
.
dtype
);
const
auto
dtype_bits
=
transformer_engine
::
pytorch
::
typeToNumBits
(
dtype
);
const
auto
size
=
product
(
shape
,
0
,
shape
.
ndim
);
output_scales_bytes
=
roundup
(
output_scales_bytes
,
16
);
// align to 16B
output_scales_offsets
.
push_back
(
output_scales_bytes
);
output_scales_bytes
+=
ceildiv
(
size
*
dtype_bits
,
8
);
}
// Allocate buffer for swizzled scales
auto
output_scales_pyt
=
allocateSpace
(
std
::
vector
<
size_t
>
{
output_scales_bytes
},
transformer_engine
::
DType
::
kByte
,
false
);
uint8_t
*
output_scales_dptr
=
reinterpret_cast
<
uint8_t
*>
(
getDataPtr
(
output_scales_pyt
));
// Construct TE tensors with only scales
std
::
vector
<
transformer_engine
::
TensorWrapper
>
inputs_nvte
,
outputs_nvte
;
for
(
size_t
i
=
0
;
i
<
tensors_needing_swizzle
.
size
();
++
i
)
{
auto
&
tensor
=
*
tensors_needing_swizzle
[
i
];
inputs_nvte
.
emplace_back
(
scaling_mode
);
outputs_nvte
.
emplace_back
(
scaling_mode
);
auto
&
input_nvte
=
inputs_nvte
.
back
();
auto
&
output_nvte
=
outputs_nvte
.
back
();
output_nvte
.
set_with_gemm_swizzled_scales
(
true
);
if
(
rowwise_usage
)
{
const
auto
data_nvte
=
tensor
.
get_rowwise_data
();
const
auto
scales_nvte
=
tensor
.
get_rowwise_scale_inv
();
const
auto
data_dtype
=
static_cast
<
transformer_engine
::
DType
>
(
data_nvte
.
dtype
);
const
auto
scales_dtype
=
static_cast
<
transformer_engine
::
DType
>
(
scales_nvte
.
dtype
);
input_nvte
.
set_rowwise_data
(
nullptr
,
data_dtype
,
data_nvte
.
shape
);
input_nvte
.
set_rowwise_scale_inv
(
scales_nvte
.
data_ptr
,
scales_dtype
,
scales_nvte
.
shape
);
output_nvte
.
set_rowwise_data
(
nullptr
,
data_dtype
,
data_nvte
.
shape
);
output_nvte
.
set_rowwise_scale_inv
(
output_scales_dptr
+
output_scales_offsets
[
i
],
scales_dtype
,
scales_nvte
.
shape
);
}
else
{
const
auto
data_nvte
=
tensor
.
get_columnwise_data
();
const
auto
scales_nvte
=
tensor
.
get_columnwise_scale_inv
();
const
auto
data_dtype
=
static_cast
<
transformer_engine
::
DType
>
(
data_nvte
.
dtype
);
const
auto
scales_dtype
=
static_cast
<
transformer_engine
::
DType
>
(
scales_nvte
.
dtype
);
input_nvte
.
set_columnwise_data
(
nullptr
,
data_dtype
,
data_nvte
.
shape
);
input_nvte
.
set_columnwise_scale_inv
(
scales_nvte
.
data_ptr
,
scales_dtype
,
scales_nvte
.
shape
);
output_nvte
.
set_columnwise_data
(
nullptr
,
data_dtype
,
data_nvte
.
shape
);
output_nvte
.
set_columnwise_scale_inv
(
output_scales_dptr
+
output_scales_offsets
[
i
],
scales_dtype
,
scales_nvte
.
shape
);
}
}
// Pack raw NVTETensors into vectors
std
::
vector
<
NVTETensor
>
inputs_nvte_raw
,
outputs_nvte_raw
;
for
(
auto
&
tensor
:
inputs_nvte
)
{
inputs_nvte_raw
.
emplace_back
(
tensor
.
data
());
}
for
(
auto
&
tensor
:
outputs_nvte
)
{
outputs_nvte_raw
.
emplace_back
(
tensor
.
data
());
}
// Launch kernel
NVTE_SCOPED_GIL_RELEASE
({
nvte_multi_tensor_swizzle_scaling_factors
(
inputs_nvte_raw
.
data
(),
outputs_nvte_raw
.
data
(),
inputs_nvte_raw
.
size
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
// Update tensors with swizzled scales
for
(
size_t
i
=
0
;
i
<
tensors_needing_swizzle
.
size
();
++
i
)
{
auto
&
tensor
=
*
tensors_needing_swizzle
[
i
];
reset_tensor_data
(
tensor
,
!
rowwise_usage
,
!
columnwise_usage
);
tensor
.
set_with_gemm_swizzled_scales
(
true
);
if
(
rowwise_usage
)
{
auto
scales_nvte
=
outputs_nvte
[
i
].
get_rowwise_scale_inv
();
const
auto
scales_dtype
=
static_cast
<
transformer_engine
::
DType
>
(
scales_nvte
.
dtype
);
tensor
.
set_rowwise_scale_inv
(
output_scales_dptr
+
output_scales_offsets
[
i
],
scales_dtype
,
scales_nvte
.
shape
);
}
else
{
auto
scales_nvte
=
outputs_nvte
[
i
].
get_columnwise_scale_inv
();
const
auto
scales_dtype
=
static_cast
<
transformer_engine
::
DType
>
(
scales_nvte
.
dtype
);
tensor
.
set_columnwise_scale_inv
(
output_scales_dptr
+
output_scales_offsets
[
i
],
scales_dtype
,
scales_nvte
.
shape
);
}
}
return
std
::
move
(
output_scales_pyt
);
}
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input,
                                                 bool rowwise) {
  // Reinterprets an FP8 block-scaling tensor as an MXFP8 tensor by converting its
  // scaling factors to the swizzled MXFP8 layout. The data buffer is shared with
  // the input; only the scales are rewritten. On return, `input` is replaced by
  // the MXFP8 view and the returned at::Tensor owns the swizzled scale buffer,
  // which the caller must keep alive for the duration of the GEMM.

  // The conversion is only defined for block-scaling inputs.
  const NVTEScalingMode scaling_mode = input.scaling_mode();
  NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
             "Input tensor must be a block scaling tensor");

  // Flatten the selected data buffer to 2D: all leading dims collapse into the
  // first dim for row-wise data, all trailing dims into the last dim for
  // column-wise data.
  NVTEBasicTensor src_data;
  size_t flat_rows = 1;
  size_t flat_cols = 1;
  if (rowwise) {
    src_data = input.get_rowwise_data();
    const size_t ndim = src_data.shape.ndim;
    for (size_t i = 0; i + 1 < ndim; ++i) {
      flat_rows *= src_data.shape.data[i];
    }
    flat_cols = src_data.shape.data[ndim - 1];
  } else {
    src_data = input.get_columnwise_data();
    flat_rows = src_data.shape.data[0];
    for (size_t i = 1; i < src_data.shape.ndim; ++i) {
      flat_cols *= src_data.shape.data[i];
    }
  }
  NVTEShape flat_shape{};
  flat_shape.data[0] = flat_rows;
  flat_shape.data[1] = flat_cols;
  flat_shape.ndim = 2;

  // Build a row-wise view of the input for the conversion kernel, reusing the
  // original data pointer and the appropriate (row- or column-wise) scales.
  transformer_engine::TensorWrapper block_input(scaling_mode);
  block_input.set_rowwise_data(src_data.data_ptr, input.dtype(), flat_shape);
  const NVTEBasicTensor src_scales =
      rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();
  block_input.set_rowwise_scale_inv(src_scales.data_ptr,
                                    static_cast<transformer_engine::DType>(src_scales.dtype),
                                    src_scales.shape);

  // The MXFP8 output aliases the same data buffer.
  transformer_engine::TensorWrapper mxfp8_output(NVTE_MXFP8_1D_SCALING);
  mxfp8_output.set_rowwise_data(src_data.data_ptr, input.dtype(), flat_shape);

  // Swizzled MXFP8 scale layout: rows padded to a multiple of 128; 4 scale
  // bytes per 128-element block along the last dim.
  const size_t padded_scale_rows = ceildiv(flat_rows, 128) * 128;
  const size_t padded_scale_cols = ceildiv(flat_cols, 128) * 4;

  // Allocate the swizzled scale buffer and attach it to the output.
  at::Tensor swizzled_scale_inv = allocateSpace(
      std::vector<size_t>{padded_scale_rows, padded_scale_cols},
      transformer_engine::DType::kByte, false);
  void* const swizzled_scale_dptr = getDataPtr(swizzled_scale_inv, 0);
  NVTEShape swizzled_scale_shape{};
  swizzled_scale_shape.data[0] = padded_scale_rows;
  swizzled_scale_shape.data[1] = padded_scale_cols;
  swizzled_scale_shape.ndim = 2;
  mxfp8_output.set_rowwise_scale_inv(swizzled_scale_dptr,
                                     transformer_engine::DType::kFloat8E8M0,
                                     swizzled_scale_shape);
  mxfp8_output.set_with_gemm_swizzled_scales(true);

  // Convert scales from FP8 block-scaling GEMM_READY format to the swizzled
  // MXFP8 format on the current CUDA stream.
  NVTE_SCOPED_GIL_RELEASE({
    nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(block_input.data(), mxfp8_output.data(),
                                                        at::cuda::getCurrentCUDAStream());
  });

  // Replace the input with the converted MXFP8 view; return the scale buffer
  // so the caller can keep it alive during the GEMM.
  input = std::move(mxfp8_output);
  return swizzled_scale_inv;
}
void inplace_swizzle_scale_for_gemm(py::handle& tensor) {
  // Swizzles a Python quantized tensor's scale-inverse buffers into the layout
  // expected by GEMM, in place. No-op for tensor formats that do not need
  // swizzling or if the scales are already swizzled.

  // Convert Python tensor to C++ tensor
  auto tensor_nvte = makeTransformerEngineTensor(tensor, py::none());

  // Return early if scale swizzling is not required
  const auto scaling_mode = tensor_nvte.scaling_mode();
  switch (scaling_mode) {
    case NVTE_MXFP8_1D_SCALING:
    case NVTE_NVFP4_1D_SCALING:
      // Tensor format requires scale swizzling
      break;
    case NVTE_INVALID_SCALING:
      NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
    default:
      // Tensor format does not require scale swizzling for GEMM
      return;
  }

  // Return early if scales are already swizzled
  if (tensor_nvte.get_with_gemm_swizzled_scales()) {
    return;
  }

  // Check what scaling factors the tensor contains. An "empty" scale buffer is
  // represented as a 1-D tensor with zero elements.
  auto is_empty = [](const NVTEBasicTensor& t) -> bool {
    return t.shape.ndim == 1 && t.shape.data[0] == 0;
  };
  const bool has_rowwise_scales = !is_empty(tensor_nvte.get_rowwise_scale_inv());
  const bool has_columnwise_scales = !is_empty(tensor_nvte.get_columnwise_scale_inv());

  // Swizzle scaling factors
  auto [rowwise_scales, columnwise_scales] =
      swizzle_scales_for_gemm(tensor_nvte, has_rowwise_scales, has_columnwise_scales);

  // Update Python tensor with swizzled scales. MXFP8 and NVFP4 tensors expose
  // identical scale attributes, so both cases share one body (previously two
  // duplicated, byte-identical case bodies).
  switch (scaling_mode) {
    case NVTE_MXFP8_1D_SCALING:
    case NVTE_NVFP4_1D_SCALING:
      if (has_rowwise_scales) {
        tensor.attr("_rowwise_scale_inv") = rowwise_scales;
      }
      if (has_columnwise_scales) {
        tensor.attr("_columnwise_scale_inv") = columnwise_scales;
      }
      tensor.attr("_with_gemm_swizzled_scales") = true;
      break;
    default:
      NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
  }
}
}
// namespace pytorch
}
// namespace transformer_engine
transformer_engine/pytorch/csrc/extensions/transpose.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
...
transformer_engine/pytorch/csrc/pybind.h
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
...
transformer_engine/pytorch/csrc/quantizer.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
@@ -52,10 +52,12 @@ Quantizer::Quantizer(const py::handle& quantizer) {
...
@@ -52,10 +52,12 @@ Quantizer::Quantizer(const py::handle& quantizer) {
this
->
rowwise_usage
=
true
;
this
->
rowwise_usage
=
true
;
this
->
columnwise_usage
=
true
;
this
->
columnwise_usage
=
true
;
this
->
internal
=
false
;
this
->
internal
=
false
;
this
->
optimize_for_gemm
=
false
;
}
else
{
}
else
{
this
->
rowwise_usage
=
quantizer
.
attr
(
"rowwise_usage"
).
cast
<
bool
>
();
this
->
rowwise_usage
=
quantizer
.
attr
(
"rowwise_usage"
).
cast
<
bool
>
();
this
->
columnwise_usage
=
quantizer
.
attr
(
"columnwise_usage"
).
cast
<
bool
>
();
this
->
columnwise_usage
=
quantizer
.
attr
(
"columnwise_usage"
).
cast
<
bool
>
();
this
->
internal
=
quantizer
.
attr
(
"internal"
).
cast
<
bool
>
();
this
->
internal
=
quantizer
.
attr
(
"internal"
).
cast
<
bool
>
();
this
->
optimize_for_gemm
=
quantizer
.
attr
(
"optimize_for_gemm"
).
cast
<
bool
>
();
this
->
quantizer
=
quantizer
;
this
->
quantizer
=
quantizer
;
}
}
}
}
...
@@ -555,7 +557,6 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
...
@@ -555,7 +557,6 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
this
->
amax_epsilon
=
quantizer
.
attr
(
"amax_epsilon"
).
cast
<
float
>
();
this
->
amax_epsilon
=
quantizer
.
attr
(
"amax_epsilon"
).
cast
<
float
>
();
NVTE_CHECK
(
this
->
block_scaling_dim
==
1
||
this
->
block_scaling_dim
==
2
,
NVTE_CHECK
(
this
->
block_scaling_dim
==
1
||
this
->
block_scaling_dim
==
2
,
"Unsupported block scaling dim."
);
"Unsupported block scaling dim."
);
this
->
all_gather_usage
=
quantizer
.
attr
(
"all_gather_usage"
).
cast
<
bool
>
();
}
}
void
Float8BlockQuantizer
::
set_quantization_params
(
TensorWrapper
*
tensor
)
const
{}
void
Float8BlockQuantizer
::
set_quantization_params
(
TensorWrapper
*
tensor
)
const
{}
...
@@ -575,10 +576,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
...
@@ -575,10 +576,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
opts
=
opts
.
dtype
(
torch
::
kUInt8
).
device
(
torch
::
kCUDA
);
opts
=
opts
.
dtype
(
torch
::
kUInt8
).
device
(
torch
::
kCUDA
);
scale_opts
=
scale_opts
.
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
);
scale_opts
=
scale_opts
.
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
);
Float8BlockScaleTensorFormat
data_format
=
(
all_gather_usage
?
Float8BlockScaleTensorFormat
::
COMPACT
:
Float8BlockScaleTensorFormat
::
GEMM_READY
);
if
(
rowwise_usage
)
{
if
(
rowwise_usage
)
{
data_rowwise
=
at
::
empty
(
torch_shape
,
opts
);
data_rowwise
=
at
::
empty
(
torch_shape
,
opts
);
auto
scale_shape
=
get_scale_shape
(
shape
,
false
);
auto
scale_shape
=
get_scale_shape
(
shape
,
false
);
...
@@ -597,7 +594,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
...
@@ -597,7 +594,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
NVTE_CHECK
(
torch_shape
.
size
()
==
shape
.
size
(),
"Shape expected to match torch shape. Shape "
,
NVTE_CHECK
(
torch_shape
.
size
()
==
shape
.
size
(),
"Shape expected to match torch shape. Shape "
,
columnwise_shape
,
" torch shape: "
,
torch_columnwise_shape
);
columnwise_shape
,
" torch shape: "
,
torch_columnwise_shape
);
if
(
torch_shape
.
size
()
>
0
)
{
if
(
torch_shape
.
size
()
>
0
)
{
if
(
!
all_gather_usage
)
{
torch_columnwise_shape
.
reserve
(
torch_shape
.
size
());
torch_columnwise_shape
.
reserve
(
torch_shape
.
size
());
columnwise_shape
.
reserve
(
shape
.
size
());
columnwise_shape
.
reserve
(
shape
.
size
());
torch_columnwise_shape
.
push_back
(
torch_shape
[
torch_shape
.
size
()
-
1
]);
torch_columnwise_shape
.
push_back
(
torch_shape
[
torch_shape
.
size
()
-
1
]);
...
@@ -606,13 +602,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
...
@@ -606,13 +602,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
torch_columnwise_shape
.
push_back
(
torch_shape
[
i
]);
torch_columnwise_shape
.
push_back
(
torch_shape
[
i
]);
columnwise_shape
.
push_back
(
shape
[
i
]);
columnwise_shape
.
push_back
(
shape
[
i
]);
}
}
}
else
{
// assert we are doing 1D scaling
NVTE_CHECK
(
block_scaling_dim
==
1
,
"Compact columnwise format is not supported for 128x128 2D block scaling."
);
torch_columnwise_shape
=
torch_shape
;
columnwise_shape
=
shape
;
}
}
}
auto
scale_shape
=
get_scale_shape
(
shape
,
true
);
auto
scale_shape
=
get_scale_shape
(
shape
,
true
);
size_t
sinv0
=
scale_shape
[
0
];
size_t
sinv0
=
scale_shape
[
0
];
...
@@ -635,7 +624,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
...
@@ -635,7 +624,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
"rowwise_data"
_a
=
data_rowwise
,
"columnwise_data"
_a
=
data_colwise
,
"rowwise_data"
_a
=
data_rowwise
,
"columnwise_data"
_a
=
data_colwise
,
"rowwise_scale_inv"
_a
=
scale_inv_rowwise
,
"columnwise_scale_inv"
_a
=
scale_inv_colwise
,
"rowwise_scale_inv"
_a
=
scale_inv_rowwise
,
"columnwise_scale_inv"
_a
=
scale_inv_colwise
,
"fp8_dtype"
_a
=
this
->
dtype
,
"quantizer"
_a
=
this
->
quantizer
,
"fp8_dtype"
_a
=
this
->
dtype
,
"quantizer"
_a
=
this
->
quantizer
,
"is_2D_scaled"
_a
=
(
block_scaling_dim
==
2
)
,
"data_format"
_a
=
data_format
);
"is_2D_scaled"
_a
=
(
block_scaling_dim
==
2
));
}
else
{
}
else
{
py
::
handle
Float8BlockwiseQTensorClass
(
py
::
handle
Float8BlockwiseQTensorClass
(
reinterpret_cast
<
PyObject
*>
(
Float8BlockwiseQTensorPythonClass
));
reinterpret_cast
<
PyObject
*>
(
Float8BlockwiseQTensorPythonClass
));
...
@@ -643,8 +632,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
...
@@ -643,8 +632,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
"shape"
_a
=
torch_shape
,
"dtype"
_a
=
GetATenDType
(
dtype
),
"rowwise_data"
_a
=
data_rowwise
,
"shape"
_a
=
torch_shape
,
"dtype"
_a
=
GetATenDType
(
dtype
),
"rowwise_data"
_a
=
data_rowwise
,
"columnwise_data"
_a
=
data_colwise
,
"rowwise_scale_inv"
_a
=
scale_inv_rowwise
,
"columnwise_data"
_a
=
data_colwise
,
"rowwise_scale_inv"
_a
=
scale_inv_rowwise
,
"columnwise_scale_inv"
_a
=
scale_inv_colwise
,
"fp8_dtype"
_a
=
this
->
dtype
,
"columnwise_scale_inv"
_a
=
scale_inv_colwise
,
"fp8_dtype"
_a
=
this
->
dtype
,
"quantizer"
_a
=
this
->
quantizer
,
"is_2D_scaled"
_a
=
(
block_scaling_dim
==
2
),
"quantizer"
_a
=
this
->
quantizer
,
"is_2D_scaled"
_a
=
(
block_scaling_dim
==
2
));
"data_format"
_a
=
data_format
);
}
}
return
{
std
::
move
(
tensor
),
std
::
move
(
ret
)};
return
{
std
::
move
(
tensor
),
std
::
move
(
ret
)};
...
@@ -654,6 +642,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
...
@@ -654,6 +642,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
py
::
object
tensor
)
const
{
py
::
object
tensor
)
const
{
const
DType
dtype
=
tensor
.
attr
(
"_fp8_dtype"
).
cast
<
DType
>
();
const
DType
dtype
=
tensor
.
attr
(
"_fp8_dtype"
).
cast
<
DType
>
();
bool
is_2D_scaled
=
tensor
.
attr
(
"_is_2D_scaled"
).
cast
<
bool
>
();
bool
is_2D_scaled
=
tensor
.
attr
(
"_is_2D_scaled"
).
cast
<
bool
>
();
const
bool
with_gemm_swizzled_scales
=
true
;
// Extract buffers from Python tensor
// Extract buffers from Python tensor
auto
get_tensor
=
[
&
tensor
](
const
char
*
name
)
->
std
::
optional
<
at
::
Tensor
>
{
auto
get_tensor
=
[
&
tensor
](
const
char
*
name
)
->
std
::
optional
<
at
::
Tensor
>
{
...
@@ -675,13 +664,10 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
...
@@ -675,13 +664,10 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
opts
=
opts
.
dtype
(
torch
::
kUInt8
).
device
(
torch
::
kCUDA
);
opts
=
opts
.
dtype
(
torch
::
kUInt8
).
device
(
torch
::
kCUDA
);
scale_opts
=
scale_opts
.
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
);
scale_opts
=
scale_opts
.
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
);
auto
get_columnwise_shape
=
[
&
columnwise_data
](
bool
all_gather_usage
)
->
std
::
vector
<
size_t
>
{
auto
get_columnwise_shape
=
[
&
columnwise_data
]()
->
std
::
vector
<
size_t
>
{
if
(
!
columnwise_data
)
{
if
(
!
columnwise_data
)
{
return
std
::
vector
<
size_t
>
();
return
std
::
vector
<
size_t
>
();
}
}
if
(
all_gather_usage
)
{
return
getTensorShape
(
*
columnwise_data
);
}
std
::
vector
<
size_t
>
shape
=
getTensorShape
(
*
columnwise_data
);
std
::
vector
<
size_t
>
shape
=
getTensorShape
(
*
columnwise_data
);
std
::
vector
<
size_t
>
shape_transposed
(
shape
.
size
());
std
::
vector
<
size_t
>
shape_transposed
(
shape
.
size
());
for
(
size_t
i
=
0
;
i
+
1
<
shape
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
+
1
<
shape
.
size
();
++
i
)
{
...
@@ -696,12 +682,12 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
...
@@ -696,12 +682,12 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
if
(
rowwise_data
)
{
if
(
rowwise_data
)
{
shape
=
getTensorShape
(
*
rowwise_data
);
shape
=
getTensorShape
(
*
rowwise_data
);
if
(
columnwise_data
)
{
if
(
columnwise_data
)
{
auto
expected_shape
=
get_columnwise_shape
(
all_gather_usage
);
auto
expected_shape
=
get_columnwise_shape
();
NVTE_CHECK
(
shape
==
expected_shape
,
"BlockwiseFP8 row-wise data (shape="
,
shape
,
NVTE_CHECK
(
shape
==
expected_shape
,
"BlockwiseFP8 row-wise data (shape="
,
shape
,
") and column-wise data (shape="
,
expected_shape
,
") do not match"
);
") and column-wise data (shape="
,
expected_shape
,
") do not match"
);
}
}
}
else
{
}
else
{
shape
=
get_columnwise_shape
(
all_gather_usage
);
shape
=
get_columnwise_shape
();
}
}
std
::
vector
<
int64_t
>
torch_shape
;
std
::
vector
<
int64_t
>
torch_shape
;
for
(
auto
s
:
shape
)
{
for
(
auto
s
:
shape
)
{
...
@@ -738,7 +724,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
...
@@ -738,7 +724,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
std
::
vector
<
size_t
>
columnwise_shape
;
std
::
vector
<
size_t
>
columnwise_shape
;
std
::
vector
<
int64_t
>
torch_columnwise_shape
;
std
::
vector
<
int64_t
>
torch_columnwise_shape
;
if
(
torch_shape
.
size
()
>
0
)
{
if
(
torch_shape
.
size
()
>
0
)
{
if
(
!
all_gather_usage
)
{
torch_columnwise_shape
.
reserve
(
torch_shape
.
size
());
torch_columnwise_shape
.
reserve
(
torch_shape
.
size
());
columnwise_shape
.
reserve
(
shape
.
size
());
columnwise_shape
.
reserve
(
shape
.
size
());
torch_columnwise_shape
.
push_back
(
torch_shape
[
torch_shape
.
size
()
-
1
]);
torch_columnwise_shape
.
push_back
(
torch_shape
[
torch_shape
.
size
()
-
1
]);
...
@@ -747,13 +732,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
...
@@ -747,13 +732,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
torch_columnwise_shape
.
push_back
(
torch_shape
[
i
]);
torch_columnwise_shape
.
push_back
(
torch_shape
[
i
]);
columnwise_shape
.
push_back
(
shape
[
i
]);
columnwise_shape
.
push_back
(
shape
[
i
]);
}
}
}
else
{
// assert we are doing 1D scaling
NVTE_CHECK
(
block_scaling_dim
==
1
,
"Compact columnwise format is not supported for 128x128 2D block scaling."
);
torch_columnwise_shape
=
torch_shape
;
columnwise_shape
=
shape
;
}
}
}
if
(
!
columnwise_data
)
{
if
(
!
columnwise_data
)
{
columnwise_data
=
at
::
empty
(
torch_columnwise_shape
,
opts
);
columnwise_data
=
at
::
empty
(
torch_columnwise_shape
,
opts
);
...
@@ -798,6 +776,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
...
@@ -798,6 +776,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
const
auto
scale_inv_colwise_shape
=
getTensorShape
(
scale_inv_colwise
);
const
auto
scale_inv_colwise_shape
=
getTensorShape
(
scale_inv_colwise
);
ret
.
set_columnwise_scale_inv
(
scale_inv_colwise_dptr
,
DType
::
kFloat32
,
scale_inv_colwise_shape
);
ret
.
set_columnwise_scale_inv
(
scale_inv_colwise_dptr
,
DType
::
kFloat32
,
scale_inv_colwise_shape
);
}
}
ret
.
set_with_gemm_swizzled_scales
(
with_gemm_swizzled_scales
);
set_quantization_params
(
&
ret
);
set_quantization_params
(
&
ret
);
return
{
std
::
move
(
ret
),
std
::
move
(
tensor
)};
return
{
std
::
move
(
ret
),
std
::
move
(
tensor
)};
}
}
...
@@ -813,9 +792,6 @@ void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& o
...
@@ -813,9 +792,6 @@ void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& o
}
}
quant_config
.
set_force_pow_2_scales
(
force_pow_2_scales
);
quant_config
.
set_force_pow_2_scales
(
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
amax_epsilon
);
quant_config
.
set_amax_epsilon
(
amax_epsilon
);
if
(
all_gather_usage
)
{
quant_config
.
set_float8_block_scale_tensor_format
(
Float8BlockScaleTensorFormat
::
COMPACT
);
}
NVTE_SCOPED_GIL_RELEASE
({
NVTE_SCOPED_GIL_RELEASE
({
nvte_quantize_v2
(
input
.
data
(),
out
.
data
(),
quant_config
,
at
::
cuda
::
getCurrentCUDAStream
());
nvte_quantize_v2
(
input
.
data
(),
out
.
data
(),
quant_config
,
at
::
cuda
::
getCurrentCUDAStream
());
});
});
...
@@ -832,10 +808,6 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
...
@@ -832,10 +808,6 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t
m_dim
=
numel
/
k_dim
;
size_t
m_dim
=
numel
/
k_dim
;
size_t
kBlockLen
=
static_cast
<
size_t
>
(
blockwise_fp8_block_len
());
size_t
kBlockLen
=
static_cast
<
size_t
>
(
blockwise_fp8_block_len
());
Float8BlockScaleTensorFormat
data_format
=
(
all_gather_usage
?
Float8BlockScaleTensorFormat
::
COMPACT
:
Float8BlockScaleTensorFormat
::
GEMM_READY
);
std
::
vector
<
size_t
>
scale_shape
;
std
::
vector
<
size_t
>
scale_shape
;
bool
rowwise_usage
=
!
columnwise
;
bool
rowwise_usage
=
!
columnwise
;
...
@@ -845,23 +817,14 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
...
@@ -845,23 +817,14 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t
sinv0
=
0
;
size_t
sinv0
=
0
;
size_t
sinv1
=
0
;
size_t
sinv1
=
0
;
if
(
block_scaling_dim
==
2
)
{
if
(
block_scaling_dim
==
2
)
{
// 2D scaling is always GEMM_READY for now
sinv0
=
ceildiv
(
m_dim
,
kBlockLen
);
NVTE_CHECK
(
data_format
==
Float8BlockScaleTensorFormat
::
GEMM_READY
,
sinv1
=
roundup
(
ceildiv
(
k_dim
,
kBlockLen
),
4
);
"2D scaling is always GEMM_READY for now."
);
sinv0
=
(
m_dim
+
kBlockLen
-
1
)
/
kBlockLen
;
sinv1
=
roundup
((
k_dim
+
kBlockLen
-
1
)
/
kBlockLen
,
4
);
}
else
if
(
block_scaling_dim
==
1
)
{
}
else
if
(
block_scaling_dim
==
1
)
{
// 1D scaling can be GEMM_READY or COMPACT
bool
rowwise_compact
=
data_format
==
Float8BlockScaleTensorFormat
::
COMPACT
;
// default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY
// default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY
sinv0
=
(
k_dim
+
kBlockLen
-
1
)
/
kBlockLen
;
sinv0
=
ceildiv
(
k_dim
,
kBlockLen
);
sinv1
=
rowwise_compact
?
m_dim
:
roundup
(
m_dim
,
4
);
sinv1
=
roundup
(
m_dim
,
4
);
// if the rowwise format is compact, the scaling factor is not be transposed
if
(
rowwise_compact
)
{
std
::
swap
(
sinv0
,
sinv1
);
}
}
else
{
}
else
{
NVTE_
CHECK
(
false
,
NVTE_
ERROR
(
"Unsupported block_scaling_dim in create_tensor rowwise."
"Unsupported block_scaling_dim in create_tensor rowwise."
"Expected 1 or 2. Got "
,
"Expected 1 or 2. Got "
,
block_scaling_dim
);
block_scaling_dim
);
...
@@ -872,21 +835,13 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
...
@@ -872,21 +835,13 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t
sinv0
=
0
;
size_t
sinv0
=
0
;
size_t
sinv1
=
0
;
size_t
sinv1
=
0
;
if
(
block_scaling_dim
==
2
)
{
if
(
block_scaling_dim
==
2
)
{
// 2D scaling is always GEMM_READY for now
sinv0
=
ceildiv
(
k_dim
,
kBlockLen
);
NVTE_CHECK
(
data_format
==
Float8BlockScaleTensorFormat
::
GEMM_READY
,
sinv1
=
roundup
(
ceildiv
(
m_dim
,
kBlockLen
),
4
);
"2D scaling is always GEMM_READY for now."
);
sinv0
=
(
k_dim
+
kBlockLen
-
1
)
/
kBlockLen
;
sinv1
=
roundup
((
m_dim
+
kBlockLen
-
1
)
/
kBlockLen
,
4
);
}
else
if
(
block_scaling_dim
==
1
)
{
}
else
if
(
block_scaling_dim
==
1
)
{
// 1D scaling can be GEMM_READY or COMPACT
sinv0
=
ceildiv
(
m_dim
,
kBlockLen
);
bool
columnwise_compact
=
data_format
==
Float8BlockScaleTensorFormat
::
COMPACT
;
sinv1
=
roundup
(
k_dim
,
4
);
sinv0
=
(
m_dim
+
kBlockLen
-
1
)
/
kBlockLen
;
sinv1
=
columnwise_compact
?
k_dim
:
roundup
(
k_dim
,
4
);
// GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
// for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1]
// so no need to swap sinv0 and sinv1 here
}
else
{
}
else
{
NVTE_
CHECK
(
false
,
NVTE_
ERROR
(
"Unsupported block_scaling_dim in create_tensor columnwise."
"Unsupported block_scaling_dim in create_tensor columnwise."
"Expected 1 or 2. Got "
,
"Expected 1 or 2. Got "
,
block_scaling_dim
);
block_scaling_dim
);
...
@@ -906,6 +861,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
...
@@ -906,6 +861,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
DType
dtype
)
const
{
DType
dtype
)
const
{
using
namespace
pybind11
::
literals
;
using
namespace
pybind11
::
literals
;
// Scaling factor format
const
bool
with_gemm_swizzled_scales
=
this
->
optimize_for_gemm
;
// Tensor dimensions
// Tensor dimensions
const
std
::
vector
<
int64_t
>
shape_int64
(
shape
.
begin
(),
shape
.
end
());
const
std
::
vector
<
int64_t
>
shape_int64
(
shape
.
begin
(),
shape
.
end
());
size_t
flat_first_dim
=
1
;
size_t
flat_first_dim
=
1
;
...
@@ -951,19 +909,17 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
...
@@ -951,19 +909,17 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
py
::
object
out_py
;
py
::
object
out_py
;
if
(
internal
)
{
if
(
internal
)
{
py
::
handle
MXFP8TensorClass
(
reinterpret_cast
<
PyObject
*>
(
MXFP8TensorStoragePythonClass
));
py
::
handle
MXFP8TensorClass
(
reinterpret_cast
<
PyObject
*>
(
MXFP8TensorStoragePythonClass
));
out_py
=
MXFP8TensorClass
(
"rowwise_data"
_a
=
rowwise_data_py
,
out_py
=
MXFP8TensorClass
(
rowwise_data_py
,
rowwise_scale_inv_py
,
columnwise_data_py
,
"columnwise_data"
_a
=
columnwise_data_py
,
columnwise_scale_inv_py
,
this
->
dtype
,
this
->
quantizer
,
"rowwise_scale_inv"
_a
=
rowwise_scale_inv_py
,
with_gemm_swizzled_scales
);
"columnwise_scale_inv"
_a
=
columnwise_scale_inv_py
,
"fp8_dtype"
_a
=
this
->
dtype
,
"quantizer"
_a
=
this
->
quantizer
);
}
else
{
}
else
{
py
::
handle
MXFP8TensorClass
(
reinterpret_cast
<
PyObject
*>
(
MXFP8TensorPythonClass
));
py
::
handle
MXFP8TensorClass
(
reinterpret_cast
<
PyObject
*>
(
MXFP8TensorPythonClass
));
out_py
=
MXFP8TensorClass
(
"shape"
_a
=
shape_int64
,
"dtype"
_a
=
GetATenDType
(
dtype
),
out_py
=
MXFP8TensorClass
(
"rowwise_data"
_a
=
rowwise_data_py
,
"shape"
_a
=
shape_int64
,
"dtype"
_a
=
GetATenDType
(
dtype
)
,
"columnwise_data"
_a
=
columnwise_data_py
,
"rowwise_data"
_a
=
rowwise_data_py
,
"columnwise_data"
_a
=
columnwise_data_py
,
"rowwise_scale_inv"
_a
=
rowwise_scale_inv_py
,
"rowwise_scale_inv"
_a
=
rowwise_scale_inv_py
,
"columnwise_scale_inv"
_a
=
columnwise_scale_inv_py
,
"columnwise_scale_inv"
_a
=
columnwise_scale_inv_py
,
"fp8_dtype"
_a
=
this
->
dtype
,
"fp8_dtype
"
_a
=
this
->
dtype
,
"
quantizer
"
_a
=
this
->
quantizer
);
"quantizer
"
_a
=
this
->
quantizer
,
"with_gemm_swizzled_scales"
_a
=
with_gemm_swizzled_scales
);
}
}
// Construct C++ MXFP8 tensor
// Construct C++ MXFP8 tensor
...
@@ -978,6 +934,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
...
@@ -978,6 +934,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
out_cpp
.
set_columnwise_scale_inv
(
columnwise_scale_inv_tensor
.
data_ptr
(),
DType
::
kFloat8E8M0
,
out_cpp
.
set_columnwise_scale_inv
(
columnwise_scale_inv_tensor
.
data_ptr
(),
DType
::
kFloat8E8M0
,
columnwise_scale_inv_shape
);
columnwise_scale_inv_shape
);
}
}
out_cpp
.
set_with_gemm_swizzled_scales
(
with_gemm_swizzled_scales
);
this
->
set_quantization_params
(
&
out_cpp
);
this
->
set_quantization_params
(
&
out_cpp
);
return
{
std
::
move
(
out_cpp
),
std
::
move
(
out_py
)};
return
{
std
::
move
(
out_cpp
),
std
::
move
(
out_py
)};
...
@@ -987,6 +944,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
...
@@ -987,6 +944,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
py
::
object
tensor
)
const
{
py
::
object
tensor
)
const
{
NVTE_CHECK
(
detail
::
IsMXFP8Tensor
(
tensor
.
ptr
()),
"MXFP8Quantizer must output to MXFP8Tensor."
);
NVTE_CHECK
(
detail
::
IsMXFP8Tensor
(
tensor
.
ptr
()),
"MXFP8Quantizer must output to MXFP8Tensor."
);
// Scaling factor format
const
bool
with_gemm_swizzled_scales
=
this
->
optimize_for_gemm
;
// Extract buffers from Python tensor
// Extract buffers from Python tensor
auto
get_tensor
=
[
&
tensor
](
const
char
*
name
)
->
std
::
optional
<
at
::
Tensor
>
{
auto
get_tensor
=
[
&
tensor
](
const
char
*
name
)
->
std
::
optional
<
at
::
Tensor
>
{
auto
attr_py
=
tensor
.
attr
(
name
);
auto
attr_py
=
tensor
.
attr
(
name
);
...
@@ -1070,6 +1030,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
...
@@ -1070,6 +1030,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
// Coerce other attrs
// Coerce other attrs
tensor
.
attr
(
"_fp8_dtype"
)
=
dtype
;
tensor
.
attr
(
"_fp8_dtype"
)
=
dtype
;
tensor
.
attr
(
"_with_gemm_swizzled_scales"
)
=
with_gemm_swizzled_scales
;
// Construct C++ MXFP8 tensor
// Construct C++ MXFP8 tensor
TensorWrapper
out_cpp
(
NVTE_MXFP8_1D_SCALING
);
TensorWrapper
out_cpp
(
NVTE_MXFP8_1D_SCALING
);
...
@@ -1083,6 +1044,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
...
@@ -1083,6 +1044,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
out_cpp
.
set_columnwise_scale_inv
(
columnwise_scale_inv
->
data_ptr
(),
DType
::
kFloat8E8M0
,
out_cpp
.
set_columnwise_scale_inv
(
columnwise_scale_inv
->
data_ptr
(),
DType
::
kFloat8E8M0
,
getTensorShape
(
*
columnwise_scale_inv
));
getTensorShape
(
*
columnwise_scale_inv
));
}
}
out_cpp
.
set_with_gemm_swizzled_scales
(
with_gemm_swizzled_scales
);
this
->
set_quantization_params
(
&
out_cpp
);
this
->
set_quantization_params
(
&
out_cpp
);
return
{
std
::
move
(
out_cpp
),
std
::
move
(
tensor
)};
return
{
std
::
move
(
out_cpp
),
std
::
move
(
tensor
)};
...
@@ -1173,6 +1135,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
...
@@ -1173,6 +1135,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
DType
dtype
)
const
{
DType
dtype
)
const
{
using
namespace
pybind11
::
literals
;
using
namespace
pybind11
::
literals
;
// Scaling factor format
const
bool
with_gemm_swizzled_scales
=
false
;
/// TODO (tmoon) self->optimize_for_gemm
// Tensor dimensions
// Tensor dimensions
const
std
::
vector
<
int64_t
>
shape_int64
(
shape
.
begin
(),
shape
.
end
());
const
std
::
vector
<
int64_t
>
shape_int64
(
shape
.
begin
(),
shape
.
end
());
size_t
flat_first_dim
=
1
;
size_t
flat_first_dim
=
1
;
...
@@ -1235,12 +1200,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
...
@@ -1235,12 +1200,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
py
::
object
out_py
;
py
::
object
out_py
;
if
(
internal
)
{
if
(
internal
)
{
py
::
handle
NVFP4TensorClass
(
reinterpret_cast
<
PyObject
*>
(
NVFP4TensorStoragePythonClass
));
py
::
handle
NVFP4TensorClass
(
reinterpret_cast
<
PyObject
*>
(
NVFP4TensorStoragePythonClass
));
out_py
=
NVFP4TensorClass
(
out_py
=
NVFP4TensorClass
(
rowwise_data_py
,
rowwise_scale_inv_py
,
columnwise_data_py
,
"rowwise_data"
_a
=
rowwise_data_py
,
"columnwise_data"
_a
=
columnwise_data_py
,
columnwise_scale_inv_py
,
amax_rowwise_py
,
amax_columnwise_py
,
"rowwise_scale_inv"
_a
=
rowwise_scale_inv_py
,
this
->
dtype
,
this
->
quantizer
,
with_gemm_swizzled_scales
);
"columnwise_scale_inv"
_a
=
columnwise_scale_inv_py
,
"amax_rowwise"
_a
=
amax_rowwise_py
,
"amax_columnwise"
_a
=
amax_columnwise_py
,
"fp4_dtype"
_a
=
this
->
dtype
,
"quantizer"
_a
=
this
->
quantizer
);
}
else
{
}
else
{
py
::
handle
NVFP4TensorClass
(
reinterpret_cast
<
PyObject
*>
(
NVFP4TensorPythonClass
));
py
::
handle
NVFP4TensorClass
(
reinterpret_cast
<
PyObject
*>
(
NVFP4TensorPythonClass
));
out_py
=
NVFP4TensorClass
(
out_py
=
NVFP4TensorClass
(
...
@@ -1249,7 +1211,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
...
@@ -1249,7 +1211,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
"rowwise_scale_inv"
_a
=
rowwise_scale_inv_py
,
"rowwise_scale_inv"
_a
=
rowwise_scale_inv_py
,
"columnwise_scale_inv"
_a
=
columnwise_scale_inv_py
,
"amax_rowwise"
_a
=
amax_rowwise_py
,
"columnwise_scale_inv"
_a
=
columnwise_scale_inv_py
,
"amax_rowwise"
_a
=
amax_rowwise_py
,
"amax_columnwise"
_a
=
amax_columnwise_py
,
"fp4_dtype"
_a
=
this
->
dtype
,
"amax_columnwise"
_a
=
amax_columnwise_py
,
"fp4_dtype"
_a
=
this
->
dtype
,
"quantizer"
_a
=
this
->
quantizer
);
"quantizer"
_a
=
this
->
quantizer
,
"with_gemm_swizzled_scales"
_a
=
with_gemm_swizzled_scales
);
}
}
// Construct C++ tensor
// Construct C++ tensor
...
@@ -1272,6 +1234,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
...
@@ -1272,6 +1234,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
out_cpp
.
set_columnwise_amax
(
amax_columnwise
.
data_ptr
(),
DType
::
kFloat32
,
out_cpp
.
set_columnwise_amax
(
amax_columnwise
.
data_ptr
(),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
std
::
vector
<
size_t
>
{
1
});
}
}
out_cpp
.
set_with_gemm_swizzled_scales
(
with_gemm_swizzled_scales
);
this
->
set_quantization_params
(
&
out_cpp
);
this
->
set_quantization_params
(
&
out_cpp
);
return
{
std
::
move
(
out_cpp
),
std
::
move
(
out_py
)};
return
{
std
::
move
(
out_cpp
),
std
::
move
(
out_py
)};
...
@@ -1301,6 +1264,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
...
@@ -1301,6 +1264,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
py
::
object
tensor
)
const
{
py
::
object
tensor
)
const
{
NVTE_CHECK
(
detail
::
IsNVFP4Tensor
(
tensor
.
ptr
()),
"NVFP4Quantizer must output to IsNVFP4Tensor."
);
NVTE_CHECK
(
detail
::
IsNVFP4Tensor
(
tensor
.
ptr
()),
"NVFP4Quantizer must output to IsNVFP4Tensor."
);
// Scaling factor format
const
bool
with_gemm_swizzled_scales
=
false
;
// TODO (tmoon) Enable with optimize_for_gemm
// Extract buffers from Python tensor
// Extract buffers from Python tensor
auto
get_tensor
=
[
&
tensor
](
const
char
*
name
)
->
std
::
optional
<
at
::
Tensor
>
{
auto
get_tensor
=
[
&
tensor
](
const
char
*
name
)
->
std
::
optional
<
at
::
Tensor
>
{
auto
attr_py
=
tensor
.
attr
(
name
);
auto
attr_py
=
tensor
.
attr
(
name
);
...
@@ -1438,6 +1404,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
...
@@ -1438,6 +1404,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
out_cpp
.
set_columnwise_amax
(
amax_columnwise
->
data_ptr
(),
DType
::
kFloat32
,
out_cpp
.
set_columnwise_amax
(
amax_columnwise
->
data_ptr
(),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
std
::
vector
<
size_t
>
{
1
});
}
}
out_cpp
.
set_with_gemm_swizzled_scales
(
with_gemm_swizzled_scales
);
this
->
set_quantization_params
(
&
out_cpp
);
this
->
set_quantization_params
(
&
out_cpp
);
return
{
std
::
move
(
out_cpp
),
std
::
move
(
tensor
)};
return
{
std
::
move
(
out_cpp
),
std
::
move
(
tensor
)};
...
@@ -1468,20 +1435,40 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
...
@@ -1468,20 +1435,40 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
}
}
size_t
cols
=
input
.
size
(
input
.
ndim
()
-
1
);
size_t
cols
=
input
.
size
(
input
.
ndim
()
-
1
);
// Stochastic rounding
// When both rowwise and columnwise quantization are used with RHT,
// we need separate RNG states for each to ensure they use different random numbers.
TensorWrapper
te_rng_state
;
TensorWrapper
te_rng_state
;
TensorWrapper
te_rng_state_columnwise
;
QuantizationConfigWrapper
quant_config_columnwise
;
const
bool
need_separate_columnwise_rng
=
this
->
stochastic_rounding
&&
this
->
with_rht
&&
this
->
columnwise_usage
;
if
(
this
->
stochastic_rounding
)
{
if
(
this
->
stochastic_rounding
)
{
const
size_t
rng_elts_per_thread
=
1024
;
// Wild guess, probably can be tightened
const
size_t
rng_elts_per_thread
=
1024
;
// Wild guess, probably can be tightened
auto
gen
=
at
::
get_generator_or_default
<
at
::
CUDAGeneratorImpl
>
(
auto
gen
=
at
::
get_generator_or_default
<
at
::
CUDAGeneratorImpl
>
(
std
::
nullopt
,
at
::
cuda
::
detail
::
getDefaultCUDAGenerator
());
std
::
nullopt
,
at
::
cuda
::
detail
::
getDefaultCUDAGenerator
());
at
::
PhiloxCudaState
philox_args
=
init_philox_state
(
gen
,
rng_elts_per_thread
);
auto
opts
=
at
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
);
auto
opts
=
at
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
);
// Generate RNG state for rowwise quantization
at
::
PhiloxCudaState
philox_args
=
init_philox_state
(
gen
,
rng_elts_per_thread
);
auto
rng_state
=
torch
::
empty
({
2
},
opts
);
auto
rng_state
=
torch
::
empty
({
2
},
opts
);
philox_unpack
(
philox_args
,
static_cast
<
int64_t
*>
(
rng_state
.
data_ptr
()));
philox_unpack
(
philox_args
,
static_cast
<
int64_t
*>
(
rng_state
.
data_ptr
()));
te_rng_state
=
makeTransformerEngineTensor
(
rng_state
);
te_rng_state
=
makeTransformerEngineTensor
(
rng_state
);
quant_config
.
set_rng_state
(
te_rng_state
.
data
());
quant_config
.
set_rng_state
(
te_rng_state
.
data
());
// Generate separate RNG state for columnwise quantization
if
(
need_separate_columnwise_rng
)
{
at
::
PhiloxCudaState
philox_args_columnwise
=
init_philox_state
(
gen
,
rng_elts_per_thread
);
auto
rng_state_columnwise
=
torch
::
empty
({
2
},
opts
);
philox_unpack
(
philox_args_columnwise
,
static_cast
<
int64_t
*>
(
rng_state_columnwise
.
data_ptr
()));
te_rng_state_columnwise
=
makeTransformerEngineTensor
(
rng_state_columnwise
);
quant_config_columnwise
.
set_stochastic_rounding
(
true
);
quant_config_columnwise
.
set_rng_state
(
te_rng_state_columnwise
.
data
());
}
}
}
// Restriction for the RHT cast fusion kernel
.
// Restriction for the RHT cast fusion kernel
because we are using MMA hardware for computing RHT
bool
eligible_for_rht_cast_fusion
=
bool
eligible_for_rht_cast_fusion
=
input
.
dtype
()
==
DType
::
kBFloat16
&&
rows
%
64
==
0
&&
cols
%
128
==
0
;
input
.
dtype
()
==
DType
::
kBFloat16
&&
rows
%
64
==
0
&&
cols
%
128
==
0
;
...
@@ -1609,6 +1596,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
...
@@ -1609,6 +1596,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
static_cast
<
DType
>
(
out_columnwise_amax
.
dtype
),
static_cast
<
DType
>
(
out_columnwise_amax
.
dtype
),
out_columnwise_amax
.
shape
);
out_columnwise_amax
.
shape
);
// Use separate RNG state for columnwise to ensure different random numbers than rowwise
auto
&
columnwise_quant_config
=
need_separate_columnwise_rng
?
quant_config_columnwise
:
quant_config
;
if
(
!
eligible_for_rht_cast_fusion
)
{
if
(
!
eligible_for_rht_cast_fusion
)
{
// Invoking fallback RHT kernel.
// Invoking fallback RHT kernel.
...
@@ -1637,7 +1628,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
...
@@ -1637,7 +1628,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
// Quantize kernel will treat everything as rowwise input/output, which is
// Quantize kernel will treat everything as rowwise input/output, which is
// intended.
// intended.
NVTE_SCOPED_GIL_RELEASE
({
NVTE_SCOPED_GIL_RELEASE
({
nvte_quantize_v2
(
rht_output_t_cpp
.
data
(),
out_transpose
.
data
(),
quant_config
,
stream
);
nvte_quantize_v2
(
rht_output_t_cpp
.
data
(),
out_transpose
.
data
(),
columnwise_quant_config
,
stream
);
});
});
}
else
{
}
else
{
// RHT cast fusion kernel.
// RHT cast fusion kernel.
...
@@ -1648,8 +1640,9 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
...
@@ -1648,8 +1640,9 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
NVTE_CHECK
(
false
,
"Not only supported for nvte_hadamard_transform_cast_fusion_columnwise"
);
NVTE_CHECK
(
false
,
"Not only supported for nvte_hadamard_transform_cast_fusion_columnwise"
);
#else
#else
NVTE_SCOPED_GIL_RELEASE
({
NVTE_SCOPED_GIL_RELEASE
({
nvte_hadamard_transform_cast_fusion_columnwise
(
nvte_hadamard_transform_cast_fusion_columnwise
(
input
.
data
(),
out_transpose
.
data
(),
input
.
data
(),
out_transpose
.
data
(),
rht_matrix_nvte
.
data
(),
quant_config
,
stream
);
rht_matrix_nvte
.
data
(),
columnwise_quant_config
,
stream
);
});
});
#endif
#endif
}
}
...
...
transformer_engine/pytorch/csrc/type_converters.cpp
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
@@ -55,8 +55,9 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer
...
@@ -55,8 +55,9 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer
TensorWrapper
NVTETensorFromMXFP8Tensor
(
py
::
handle
tensor
,
Quantizer
*
quantizer
)
{
TensorWrapper
NVTETensorFromMXFP8Tensor
(
py
::
handle
tensor
,
Quantizer
*
quantizer
)
{
auto
ret
=
TensorWrapper
(
NVTE_MXFP8_1D_SCALING
);
auto
ret
=
TensorWrapper
(
NVTE_MXFP8_1D_SCALING
);
bool
rowwise_usage
=
!
(
tensor
.
attr
(
"_rowwise_data"
).
is_none
());
const
bool
rowwise_usage
=
!
(
tensor
.
attr
(
"_rowwise_data"
).
is_none
());
bool
columnwise_usage
=
!
(
tensor
.
attr
(
"_columnwise_data"
).
is_none
());
const
bool
columnwise_usage
=
!
(
tensor
.
attr
(
"_columnwise_data"
).
is_none
());
const
bool
with_gemm_swizzled_scales
=
tensor
.
attr
(
"_with_gemm_swizzled_scales"
).
cast
<
bool
>
();
NVTE_CHECK
(
rowwise_usage
||
columnwise_usage
,
"No data found for MXFP8 Tensor."
);
NVTE_CHECK
(
rowwise_usage
||
columnwise_usage
,
"No data found for MXFP8 Tensor."
);
...
@@ -78,6 +79,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer)
...
@@ -78,6 +79,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer)
getTensorShape
(
scale_inv
));
getTensorShape
(
scale_inv
));
}
}
// Scale layout
ret
.
set_with_gemm_swizzled_scales
(
with_gemm_swizzled_scales
);
// Quantizer state
// Quantizer state
quantizer
->
set_quantization_params
(
&
ret
);
quantizer
->
set_quantization_params
(
&
ret
);
...
@@ -93,6 +97,7 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
...
@@ -93,6 +97,7 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
auto
ret
=
TensorWrapper
(
is_2D_scaled
?
NVTE_BLOCK_SCALING_2D
:
NVTE_BLOCK_SCALING_1D
);
auto
ret
=
TensorWrapper
(
is_2D_scaled
?
NVTE_BLOCK_SCALING_2D
:
NVTE_BLOCK_SCALING_1D
);
// Row-wise data
if
(
rowwise_usage
)
{
if
(
rowwise_usage
)
{
const
at
::
Tensor
&
data_rowwise
=
tensor
.
attr
(
"_rowwise_data"
).
cast
<
at
::
Tensor
>
();
const
at
::
Tensor
&
data_rowwise
=
tensor
.
attr
(
"_rowwise_data"
).
cast
<
at
::
Tensor
>
();
const
at
::
Tensor
&
scale_inv_rowwise
=
tensor
.
attr
(
"_rowwise_scale_inv"
).
cast
<
at
::
Tensor
>
();
const
at
::
Tensor
&
scale_inv_rowwise
=
tensor
.
attr
(
"_rowwise_scale_inv"
).
cast
<
at
::
Tensor
>
();
...
@@ -102,6 +107,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
...
@@ -102,6 +107,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
const
auto
scale_inv_rowwise_shape
=
getTensorShape
(
scale_inv_rowwise
);
const
auto
scale_inv_rowwise_shape
=
getTensorShape
(
scale_inv_rowwise
);
ret
.
set_rowwise_scale_inv
(
scale_inv_rowwise_dptr
,
DType
::
kFloat32
,
scale_inv_rowwise_shape
);
ret
.
set_rowwise_scale_inv
(
scale_inv_rowwise_dptr
,
DType
::
kFloat32
,
scale_inv_rowwise_shape
);
}
}
// Column-wise data
if
(
columnwise_usage
)
{
if
(
columnwise_usage
)
{
const
at
::
Tensor
&
data_colwise
=
tensor
.
attr
(
"_columnwise_data"
).
cast
<
at
::
Tensor
>
();
const
at
::
Tensor
&
data_colwise
=
tensor
.
attr
(
"_columnwise_data"
).
cast
<
at
::
Tensor
>
();
const
at
::
Tensor
&
scale_inv_colwise
=
tensor
.
attr
(
"_columnwise_scale_inv"
).
cast
<
at
::
Tensor
>
();
const
at
::
Tensor
&
scale_inv_colwise
=
tensor
.
attr
(
"_columnwise_scale_inv"
).
cast
<
at
::
Tensor
>
();
...
@@ -112,7 +119,10 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
...
@@ -112,7 +119,10 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
const
auto
scale_inv_colwise_shape
=
getTensorShape
(
scale_inv_colwise
);
const
auto
scale_inv_colwise_shape
=
getTensorShape
(
scale_inv_colwise
);
ret
.
set_columnwise_scale_inv
(
scale_inv_colwise_dptr
,
DType
::
kFloat32
,
scale_inv_colwise_shape
);
ret
.
set_columnwise_scale_inv
(
scale_inv_colwise_dptr
,
DType
::
kFloat32
,
scale_inv_colwise_shape
);
}
}
// Quantizer state
quantizer
->
set_quantization_params
(
&
ret
);
quantizer
->
set_quantization_params
(
&
ret
);
return
ret
;
return
ret
;
}
}
...
@@ -121,8 +131,9 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
...
@@ -121,8 +131,9 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
auto
ret
=
TensorWrapper
(
NVTE_NVFP4_1D_SCALING
);
auto
ret
=
TensorWrapper
(
NVTE_NVFP4_1D_SCALING
);
bool
rowwise_usage
=
!
(
tensor
.
attr
(
"_rowwise_data"
).
is_none
());
const
bool
rowwise_usage
=
!
(
tensor
.
attr
(
"_rowwise_data"
).
is_none
());
bool
columnwise_usage
=
!
(
tensor
.
attr
(
"_columnwise_data"
).
is_none
());
const
bool
columnwise_usage
=
!
(
tensor
.
attr
(
"_columnwise_data"
).
is_none
());
const
bool
with_gemm_swizzled_scales
=
tensor
.
attr
(
"_with_gemm_swizzled_scales"
).
cast
<
bool
>
();
NVTE_CHECK
(
rowwise_usage
||
columnwise_usage
,
"No data found for NVFP4 Tensor."
);
NVTE_CHECK
(
rowwise_usage
||
columnwise_usage
,
"No data found for NVFP4 Tensor."
);
...
@@ -150,6 +161,9 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
...
@@ -150,6 +161,9 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
getTensorShape
(
amax_columnwise
));
getTensorShape
(
amax_columnwise
));
}
}
// Scale layout
ret
.
set_with_gemm_swizzled_scales
(
with_gemm_swizzled_scales
);
// Quantizer state
// Quantizer state
quantizer
->
set_quantization_params
(
&
ret
);
quantizer
->
set_quantization_params
(
&
ret
);
...
...
transformer_engine/pytorch/csrc/util.cpp
deleted
100644 → 0
View file @
a68e5f87
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "util.h"
#include "common.h"
#include "common/common.h"
/*! \brief Swizzle the scaling factors of a quantized tensor into GEMM layout.
 *
 *  Only MXFP8 / NVFP4 1D-scaled tensors are swizzled; other block-scaled
 *  modes return std::nullopt. On success the caller's tensor is re-pointed
 *  at the swizzled scales, and the returned at::Tensor owns that memory —
 *  it must be kept alive for the duration of the GEMM.
 */
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
                                                  bool rowwise) {
  using namespace transformer_engine::pytorch;

  // Only MXFP8/NVFP4 1D scaling carries scales that need GEMM swizzling.
  const auto mode = input.scaling_mode();
  if (mode == NVTE_INVALID_SCALING) {
    NVTE_ERROR("Invalid scaling mode for swizzle.");
  } else if (mode != NVTE_MXFP8_1D_SCALING && mode != NVTE_NVFP4_1D_SCALING) {
    return std::nullopt;
  }
  NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8,
             "4-bit or 8-bit input required for swizzling scaling factors.");
  const bool nvfp4 = (mode == NVTE_NVFP4_1D_SCALING);

  // Select the data shape and scale-inverse buffer for the requested direction.
  NVTEShape nvte_input_shape;
  NVTEBasicTensor scale_inv;
  if (rowwise) {
    nvte_input_shape = input.shape();
    scale_inv = input.get_rowwise_scale_inv();
  } else {
    nvte_input_shape = input.get_columnwise_data().shape;
    scale_inv = input.get_columnwise_scale_inv();
  }
  const auto input_shape = nvte_shape_to_vector(nvte_input_shape);
  const auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
  NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape.");

  // Allocate the swizzled output (byte buffer, same extent as the original scales).
  const std::vector<int64_t> scale_inv_shape_int(scale_inv_shape.begin(), scale_inv_shape.end());
  auto swizzled_scale_inv =
      at::empty(scale_inv_shape_int, at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA));
  void* const src_scales_dptr = scale_inv.data_ptr;
  void* const dst_scales_dptr = getDataPtr(swizzled_scale_inv, 0);

  // NVFP4 packs FP4 data with E4M3 scales; MXFP8 packs FP8 data with E8M0 scales.
  const auto input_dtype = nvfp4 ? transformer_engine::DType::kFloat4E2M1
                                 : transformer_engine::DType::kFloat8E4M3;
  const auto scale_inv_dtype = nvfp4 ? transformer_engine::DType::kFloat8E4M3
                                     : transformer_engine::DType::kFloat8E8M0;

  // Describe kernel input/output. The data pointer is shared; only the
  // scale-inverse pointer differs between the two descriptors.
  transformer_engine::TensorWrapper input_cu(mode);
  transformer_engine::TensorWrapper output_cu(mode);
  if (rowwise) {
    input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
    input_cu.set_rowwise_scale_inv(src_scales_dptr, scale_inv_dtype, scale_inv_shape);
    output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
    output_cu.set_rowwise_scale_inv(dst_scales_dptr, scale_inv_dtype, scale_inv_shape);
  } else {
    input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
    input_cu.set_columnwise_scale_inv(src_scales_dptr, scale_inv_dtype, scale_inv_shape);
    output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
    output_cu.set_columnwise_scale_inv(dst_scales_dptr, scale_inv_dtype, scale_inv_shape);
  }

  // Launch the swizzle kernel on the current stream.
  nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(),
                               at::cuda::getCurrentCUDAStream());

  // Re-point the caller's tensor at the swizzled scales.
  if (rowwise) {
    input.set_rowwise_scale_inv(dst_scales_dptr, scale_inv_dtype, scale_inv_shape);
  } else {
    input.set_columnwise_scale_inv(dst_scales_dptr, scale_inv_dtype, scale_inv_shape);
  }
  return swizzled_scale_inv;
}
/*! \brief Swizzle the scaling factors of multiple quantized tensors into GEMM layout.
 *
 *  All tensors must share the same scaling mode. Only MXFP8 / NVFP4 1D
 *  scaling is swizzled; an empty list or any other mode returns
 *  std::nullopt. Each tensor's scale-inverse pointer is redirected into a
 *  single packed byte buffer (16B-aligned slices), which is returned and
 *  must be kept alive for the duration of the GEMMs.
 */
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
    std::vector<transformer_engine::TensorWrapper>& tensors, bool rowwise) {
  using namespace transformer_engine::pytorch;
  if (tensors.empty()) {
    return std::nullopt;
  }

  // All tensors must agree on the scaling mode. Hoist the reference mode out
  // of the predicate so it is evaluated once, not per element.
  const auto scaling_mode = tensors.front().scaling_mode();
  const bool all_same_scaling_mode =
      std::all_of(tensors.cbegin(), tensors.cend(),
                  [scaling_mode](const transformer_engine::TensorWrapper& val) {
                    return val.scaling_mode() == scaling_mode;
                  });
  NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same.");
  if (scaling_mode == NVTE_INVALID_SCALING) {
    NVTE_ERROR("Invalid scaling mode for swizzle.");
  } else if (scaling_mode != NVTE_MXFP8_1D_SCALING && scaling_mode != NVTE_NVFP4_1D_SCALING) {
    return std::nullopt;
  }
  const bool nvfp4 = (scaling_mode == NVTE_NVFP4_1D_SCALING);

  // Keep kernel tensor descriptors alive until the kernel launch.
  std::vector<transformer_engine::TensorWrapper> wrappers;
  std::vector<NVTETensor> input_tensors, output_tensors;

  // Collect scale_inv shapes and compute the packed-buffer size and the
  // 16B-aligned offset of each tensor's slice.
  std::vector<std::vector<size_t>> scale_inv_shapes;
  std::vector<void*> scale_inv_dptrs;
  size_t buffer_size = 0;
  std::vector<size_t> scale_inv_offsets;
  constexpr size_t scale_elem_size = 1;  // scales are 1 byte (E4M3 or E8M0)
  for (auto& tensor : tensors) {
    NVTEBasicTensor scale_inv;
    if (rowwise) {
      scale_inv = tensor.get_rowwise_scale_inv();
    } else {
      scale_inv = tensor.get_columnwise_scale_inv();
    }
    auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
    buffer_size = roundup(buffer_size, 16);  // align each slice to 16B
    scale_inv_offsets.push_back(buffer_size);
    buffer_size += product(scale_inv_shape) * scale_elem_size;
    scale_inv_shapes.emplace_back(scale_inv_shape);
    scale_inv_dptrs.push_back(scale_inv.data_ptr);
  }

  // Allocate one buffer holding every tensor's swizzled scales.
  auto buffer = at::empty({static_cast<int64_t>(buffer_size)},
                          at::device(at::kCUDA).dtype(torch::kUInt8));

  // NVFP4 packs FP4 data with E4M3 scales; MXFP8 packs FP8 data with E8M0 scales.
  const auto input_dtype = nvfp4 ? transformer_engine::DType::kFloat4E2M1
                                 : transformer_engine::DType::kFloat8E4M3;
  const auto scale_inv_dtype = nvfp4 ? transformer_engine::DType::kFloat8E4M3
                                     : transformer_engine::DType::kFloat8E8M0;
  for (size_t i = 0; i < tensors.size(); ++i) {
    auto& tensor = tensors[i];
    void* scale_inv_dptr = scale_inv_dptrs[i];
    void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
    NVTEShape nvte_input_shape;
    if (rowwise) {
      nvte_input_shape = tensor.shape();
    } else {
      nvte_input_shape = tensor.get_columnwise_data().shape;
    }
    auto input_shape = nvte_shape_to_vector(nvte_input_shape);
    // Reconstruct input only to avoid swizzling both directions if not needed.
    transformer_engine::TensorWrapper input_cu(scaling_mode);
    transformer_engine::TensorWrapper output_cu(scaling_mode);
    if (rowwise) {
      input_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
      input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
      output_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
      output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
                                      scale_inv_shapes[i]);
      // Point the original tensor at its swizzled slice.
      tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
    } else {
      input_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
      input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
      output_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
      output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
                                         scale_inv_shapes[i]);
      // Point the original tensor at its swizzled slice.
      tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
                                      scale_inv_shapes[i]);
    }
    input_tensors.emplace_back(input_cu.data());
    output_tensors.emplace_back(output_cu.data());
    wrappers.emplace_back(std::move(input_cu));
    wrappers.emplace_back(std::move(output_cu));
  }

  // Launch one batched swizzle kernel for all tensors.
  nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(),
                                            input_tensors.size(),
                                            at::cuda::getCurrentCUDAStream());
  return buffer;
}
/*! \brief Convert an FP8 block-scaling tensor to an MXFP8 tensor in place.
 *
 *  If rowwise==false, the columnwise data is reinterpreted as rowwise data to
 *  avoid a transpose in memory; callers must then treat the result as
 *  transposed. Returns the swizzled MXFP8 scaling factors, which must be kept
 *  alive during the GEMM.
 */
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input,
                                                 bool rowwise) {
  using namespace transformer_engine::pytorch;
  using transformer_engine::DIVUP;

  // Validate input scaling mode.
  const NVTEScalingMode scaling_mode = input.scaling_mode();
  NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
             "Input tensor must be a block scaling tensor");

  // Flatten the data shape to 2D: [all-but-last, last] for rowwise,
  // [first, all-but-first] for columnwise.
  NVTEBasicTensor data;
  size_t flat_rows = 1;
  size_t flat_cols = 1;
  if (rowwise) {
    data = input.get_rowwise_data();
    for (int i = 0; i + 1 < data.shape.ndim; ++i) {
      flat_rows *= data.shape.data[i];
    }
    flat_cols = data.shape.data[data.shape.ndim - 1];
  } else {
    data = input.get_columnwise_data();
    flat_rows = data.shape.data[0];
    for (int i = 1; i < data.shape.ndim; ++i) {
      flat_cols *= data.shape.data[i];
    }
  }
  NVTEShape data_shape{};
  data_shape.data[0] = flat_rows;
  data_shape.data[1] = flat_cols;
  data_shape.ndim = 2;

  // Recreate the input tensor with rowwise usage only.
  transformer_engine::TensorWrapper input_cu(scaling_mode);
  input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
  const NVTEBasicTensor scale_inv =
      rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();
  input_cu.set_rowwise_scale_inv(scale_inv.data_ptr,
                                 static_cast<transformer_engine::DType>(scale_inv.dtype),
                                 scale_inv.shape);

  // Output MXFP8 tensor aliases the same data buffer.
  transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
  output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);

  // Swizzled MXFP8 scale dimensions: rows padded to 128, one byte per
  // 32-element block in the last dim, padded to 128-wide tiles (x4).
  const size_t swizzled_first_dim = DIVUP<size_t>(flat_rows, 128) * 128;
  const size_t swizzled_last_dim = DIVUP<size_t>(flat_cols, 128) * 4;

  // Allocate memory for the swizzled MXFP8 scaling factors.
  const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
  at::Tensor swizzled_scale_inv =
      at::empty(std::vector<int64_t>{static_cast<int64_t>(swizzled_first_dim),
                                     static_cast<int64_t>(swizzled_last_dim)},
                options);

  // Attach the swizzled scales to the output descriptor.
  void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
  NVTEShape swizzled_scale_inv_shape{};
  swizzled_scale_inv_shape.data[0] = swizzled_first_dim;
  swizzled_scale_inv_shape.data[1] = swizzled_last_dim;
  swizzled_scale_inv_shape.ndim = 2;
  output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr,
                                  transformer_engine::DType::kFloat8E8M0,
                                  swizzled_scale_inv_shape);

  // Convert scaling factors from FP8 block scaling GEMM_READY format to the
  // MXFP8 swizzled format.
  nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(),
                                                      at::cuda::getCurrentCUDAStream());

  // Replace the input with the converted MXFP8 tensor; return the swizzled
  // scales so the caller keeps them alive during the GEMM.
  input = std::move(output_cu);
  return swizzled_scale_inv;
}
transformer_engine/pytorch/csrc/util.h
View file @
0d874a4e
/*************************************************************************
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
...
@@ -10,33 +10,44 @@
...
@@ -10,33 +10,44 @@
#include <torch/extension.h>
#include <torch/extension.h>
#include <optional>
#include <optional>
#include <tuple>
#include <vector>
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transformer_engine.h"
/*! \brief Swizzle the scaling factor of the input tensor.
namespace
transformer_engine
{
namespace
pytorch
{
/*! \brief Convert tensor block scales into GEMM swizzled format.
*
*
* The returned swizzled scal
ing factor tensor
should be kept alive during the GEMM.
*
The returned swizzled scal
es
should be kept alive during the GEMM.
*/
*/
std
::
optional
<
at
::
Tensor
>
swizzle_scaling_factors
(
transformer_engine
::
TensorWrapper
&
input
,
std
::
tuple
<
std
::
optional
<
at
::
Tensor
>
,
std
::
optional
<
at
::
Tensor
>>
swizzle_scales_for_gemm
(
bool
rowwis
e
);
TensorWrapper
&
tensor
,
bool
rowwise_usage
,
bool
columnwise_usag
e
);
/*! \brief
Swizzle the scaling factor of the input tensors
.
/*! \brief
Convert multiple tensor block scales into GEMM swizzled format
.
*
*
* The returned swizzled scal
ing factor tensor
s should be kept alive during the GEMMs.
*
The returned swizzled scal
e
s should be kept alive during the GEMMs.
*/
*/
std
::
optional
<
at
::
Tensor
>
multi_tensor_swizzle_scaling_factors
(
std
::
optional
<
at
::
Tensor
>
multi_tensor_swizzle_scales_for_gemm
(
std
::
vector
<
TensorWrapper
>&
tensors
,
std
::
vector
<
transformer_engine
::
TensorWrapper
>
&
inputs
,
bool
rowwise
);
bool
rowwise_usage
,
bool
columnwise_usage
);
/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place.
/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place.
*
*
* If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid
* If rowwise==false, the columnwise data will be reinterpreted as
* transposing it in memory. Due to differences in how block scaling and mxfp8 store data,
* rowwise data to avoid transposing it in memory. Due to differences
* this requires the calling code to treat the output tensor as having been tranposed in this case.
* in how block scaling and mxfp8 store data, this requires the
* calling code to treat the output tensor as having been transposed
* in this case.
*
*
* Returns the swizzled scaling factor of the converted mxfp8 tensor.
* Returns the swizzled scaling factor of the converted mxfp8 tensor.
* The returned swizzled scaling factor tensor should be kept alive during the GEMM.
* The returned swizzled scaling factor tensor should be kept alive
* during the GEMM.
*/
*/
at
::
Tensor
convert_block_scaling_to_mxfp8_tensor
(
transformer_engine
::
TensorWrapper
&
input
,
at
::
Tensor
convert_block_scaling_to_mxfp8_tensor
(
TensorWrapper
&
input
,
bool
rowwise
);
bool
rowwise
);
}
// namespace pytorch
}
// namespace transformer_engine
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
transformer_engine/pytorch/custom_recipes/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
Prev
1
…
24
25
26
27
28
29
30
31
32
33
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment