OpenDAS / TransformerEngine / Commits

Commit 0e886dab, authored Jul 01, 2025 by wenjh

Merge develop_v2.4

Signed-off-by: wenjh <wenjh@sugon.com>
Parents: e56de127, b944277c
Changes: 23 (20 files shown on this page)

Showing 20 changed files with 1123 additions and 140 deletions (+1123, -140)
tests/cpp/operator/test_cast_float8blockwise.cu  (+13, -0)
tests/cpp/test_common.cu  (+6, -6)
tests/cpp/test_common.h  (+12, -0)
tests/pytorch/references/blockwise_fp8_gemm_reference.py  (+2, -1)
tests/pytorch/references/blockwise_quantizer_reference.py  (+3, -4)
tests/pytorch/test_float8_blockwise_gemm_exact.py  (+7, -5)
tests/pytorch/test_float8_blockwise_scaling_exact.py  (+14, -8)
tests/pytorch/test_int8_blockwise_gemm_exact.py  (+14, -10)
transformer_engine/common/common.h  (+5, -0)
transformer_engine/common/recipe/fp8_block_scaling.cu  (+211, -23)
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu  (+449, -35)
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu  (+356, -35)
transformer_engine/pytorch/cpp_extensions/gemm.py  (+7, -6)
transformer_engine/pytorch/csrc/common.h  (+14, -0)
transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp  (+2, -2)
transformer_engine/pytorch/csrc/quantizer.cpp  (+1, -1)
transformer_engine/pytorch/distributed.py  (+1, -1)
transformer_engine/pytorch/fp8.py  (+1, -0)
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py  (+3, -2)
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py  (+2, -1)
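Taken together, the diffs below replace the hardcoded 128-element blockwise-FP8 tile length with a value read from the NVTE_BLOCKWISE_FP8_BLOCK_LEN environment variable, and add kernel and test support for a 64-element block. A minimal Python sketch of the pattern this introduces (the assert is illustrative only; in the code it is the C++ entry points that enforce the 128-or-64 constraint):

import os

# Mirrors the one-liner added to transformer_engine/pytorch/fp8.py below:
# read once at import time, defaulting to the historical value of 128.
blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))

# Illustrative check: the C++ side raises unless the length is 128 or 64.
assert blockwise_fp8_block_len in (128, 64)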
tests/cpp/operator/test_cast_float8blockwise.cu

@@ -25,7 +25,11 @@ struct QuantizationOptions {
   size_t block_scaling_dim = 2u;
 };

+#ifdef __HIP_PLATFORM_AMD__
+size_t kBlockLen = static_cast<size_t>(blockwise_fp8_block_len());
+#else
 constexpr size_t kBlockLen = 128;
+#endif

 enum ProcessingMethod {
   CAST_ONLY,

@@ -80,8 +84,13 @@ template <typename InputType, typename OutputType>
 void ref_quantize(const ProcessingMethod processing_method, const InputType* input,
                   const std::pair<size_t, size_t>& input_hw, OutputType* output, float* scale_inv,
                   OutputType* output_t, float* scale_inv_t, const QuantizationOptions& opts) {
+#ifdef __HIP_PLATFORM_AMD__
+  size_t kBlockLenX = kBlockLen;
+  size_t kBlockLenY = kBlockLen;
+#else
   constexpr size_t kBlockLenX = kBlockLen;
   constexpr size_t kBlockLenY = kBlockLen;
+#endif
   auto quantize_element = [](InputType element, float qscale) -> OutputType {
     // Scale in FP32 and cast result to nearest FP8.

@@ -157,7 +166,11 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method
   float input_type_max_val = Quantized_Limits<InputType>::max();
   float quant_type_max_val = Quantized_Limits<OutputType>::max();
+#ifdef __HIP_PLATFORM_AMD__
+  size_t kBlockLenX = kBlockLen;
+#else
   constexpr size_t kBlockLenX = kBlockLen;
+#endif
   auto quantize_element = [](InputType element, float qscale) -> OutputType {
     // Scale in FP32 and cast result to nearest FP8.
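For context, the reference quantizer's per-element step ("scale in FP32 and cast result to nearest FP8") can be sketched in torch. This is a hedged illustration, not the repo's reference code; quantize_block and its round-trip are hypothetical names:

import torch

def quantize_block(block: torch.Tensor, quant_max: float):
    """Quantize one tile: qscale maps the tile's amax onto the format's max value."""
    amax = block.abs().amax().item()
    qscale = quant_max / amax if amax > 0 else 1.0
    q = (block.float() * qscale).to(torch.float8_e4m3fn)  # round-to-nearest cast
    return q, 1.0 / qscale  # also return scale_inv for dequantization

x = torch.randn(64, 64)
q, scale_inv = quantize_block(x, torch.finfo(torch.float8_e4m3fn).max)
x_hat = q.float() * scale_inv  # dequantize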
tests/cpp/test_common.cu

@@ -176,13 +176,13 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
   scale_inv_meta ret_rowwise, ret_colwise;
   {
-    auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
-    auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(128)), 4) * 4;
+    auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
+    auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
     ret_rowwise.shape = {scale_dim_0, scale_dim_1};
   }
   {
-    auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
-    auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(128)), 4) * 4;
+    auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
+    auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
     ret_colwise.shape = {scale_dim_0, scale_dim_1};
   }
   ret_rowwise.type = DType::kFloat32;

@@ -202,12 +202,12 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
   scale_inv_meta ret_rowwise, ret_colwise;
   {
-    auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
+    auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
     auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
     ret_rowwise.shape = {scale_dim_0, scale_dim_1};
   }
   {
-    auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
+    auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
     auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
     ret_colwise.shape = {scale_dim_0, scale_dim_1};
   }
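The scale-tensor shape math above is plain ceiling division on the configured block length, with the trailing scale dimension padded to a multiple of 4. A sketch restating it in Python (DIVUP(a, b) is ceil(a / b); rowwise_scale_shape is a hypothetical helper name):

def divup(a: int, b: int) -> int:
    return (a + b - 1) // b

def rowwise_scale_shape(first_dim: int, last_dim: int, block_len: int):
    scale_dim_0 = divup(first_dim, block_len)
    scale_dim_1 = divup(divup(last_dim, block_len), 4) * 4  # padded to multiple of 4
    return (scale_dim_0, scale_dim_1)

# A 256x192 tensor with 64-element blocks needs a 4x4 scale grid (3 padded to 4):
assert rowwise_scale_shape(256, 192, 64) == (4, 4)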
tests/cpp/test_common.h

@@ -29,6 +29,18 @@
 namespace test {
 using namespace transformer_engine;

+inline int blockwise_fp8_block_len() {
+  const char* env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN");
+  if (env == nullptr || env[0] == '\0') {
+    return 128;
+  }
+  int value;
+  std::istringstream iss(env);
+  iss >> value;
+  NVTE_CHECK(iss, "Invalid environment variable value");
+  return value;
+}
+
 template <size_t i>
 struct BytesToType {};
tests/pytorch/references/blockwise_fp8_gemm_reference.py

@@ -8,6 +8,7 @@ import torch
 import triton
 import triton.language as tl
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len

 @triton.jit

@@ -135,7 +136,7 @@ class CuBLASRefBlockwiseGemm:
         N, K_w = qw.shape
         assert K == K_w, "K dimension mismatch between qx and qw"
-        tile_len = 128
+        tile_len = blockwise_fp8_block_len
         # Calculate grid sizes without padding
         grid_m = (M + tile_len - 1) // tile_len
         grid_n = (N + tile_len - 1) // tile_len
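The grid computation above is ceiling division on the configured tile length, restated as a small self-contained sketch (grid_dims is a hypothetical helper name):

def grid_dims(M: int, N: int, tile_len: int):
    # Number of tiles needed to cover an M x N problem, without padding.
    grid_m = (M + tile_len - 1) // tile_len
    grid_n = (N + tile_len - 1) // tile_len
    return grid_m, grid_n

assert grid_dims(300, 200, 64) == (5, 4)
assert grid_dims(300, 200, 128) == (3, 2)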
tests/pytorch/references/blockwise_quantizer_reference.py

@@ -7,7 +7,7 @@ import math
 import torch
 from typing import Optional, Protocol, Tuple
 from references.quantize_scale_calc import scale_from_amax_tensor
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len

 @dataclasses.dataclass()
 class QuantizeResult:

@@ -277,8 +277,7 @@ class BlockwiseQuantizerReference:
         return_transpose: bool = False,
         eps: float = 0.0,
         pow_2_scales: bool = False,
-        quant_tile_shape: Tuple[int, int] = (128, 128),
-        munge_scale_shapes: bool = True,
+        quant_tile_shape: Tuple[int, int] = (blockwise_fp8_block_len, blockwise_fp8_block_len),
     ) -> QuantizeResult:
         # sanity checks
         assert x.dim() == 2

@@ -294,7 +293,7 @@ class BlockwiseQuantizerReference:
             torch.int8,
         ), "Unsupported quant dtype."
-        assert quant_tile_shape in ((1, 128), (128, 128))
+        assert quant_tile_shape in ((1, blockwise_fp8_block_len), (blockwise_fp8_block_len, blockwise_fp8_block_len))
         if quant_tile_shape[0] == 1:
             # Quantize row-wise
             result = self._quantize_vector_tiling(
tests/pytorch/test_float8_blockwise_gemm_exact.py

@@ -8,7 +8,7 @@ import transformer_engine as te
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
     Float8BlockQuantizer,
     Float8BlockwiseQTensor,

@@ -77,8 +77,9 @@ def cublas_gemm_fp8_blockwise_case(
     assert not (use_bias and use_grad), "Bias grad not supported by GEMM"

     # Set quantize_op and quantization parameters
-    x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
-    w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
+    block_len = blockwise_fp8_block_len
+    x_quant_tile_shape = (1, block_len) if is_x_1d_scaled else (block_len, block_len)
+    w_quant_tile_shape = (1, block_len) if is_w_1d_scaled else (block_len, block_len)
     x_block_scaling_dim = 1 if is_x_1d_scaled else 2
     w_block_scaling_dim = 1 if is_w_1d_scaled else 2
     x_te_dtype = TE_DType[x_dtype]

@@ -247,8 +248,9 @@ def cublas_gemm_test_constraint_enforced(
     out = None

     # Set quantize_op and quantization parameters
-    x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
-    w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
+    block_len = blockwise_fp8_block_len
+    x_quant_tile_shape = (1, block_len) if is_x_1d_scaled else (block_len, block_len)
+    w_quant_tile_shape = (1, block_len) if is_w_1d_scaled else (block_len, block_len)
     x_block_scaling_dim = 1 if is_x_1d_scaled else 2
     w_block_scaling_dim = 1 if is_w_1d_scaled else 2
     x_te_dtype = TE_DType[x_dtype]
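The tile-shape selection pattern that recurs in these tests pairs a (1, block_len) tile with 1D ("vector") scaling and a (block_len, block_len) tile with 2D ("square") scaling. A sketch restating it (tile_config is a hypothetical helper name):

def tile_config(is_1d_scaled: bool, block_len: int):
    # 1D scaling: one scale per row-vector of block_len elements; 2D: one per square tile.
    if is_1d_scaled:
        return (1, block_len), 1  # tile shape, block_scaling_dim
    return (block_len, block_len), 2

assert tile_config(True, 64) == ((1, 64), 1)
assert tile_config(False, 128) == ((128, 128), 2)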
tests/pytorch/test_float8_blockwise_scaling_exact.py

@@ -10,7 +10,7 @@ import pytest
 import torch
 import transformer_engine as te
 import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
 from transformer_engine.common.recipe import Float8BlockScaling
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (

@@ -219,9 +219,9 @@ def check_quantization_block_tiling_versus_reference(
     tile_size: Tuple[int, int],
 ) -> None:
     te_dtype = TE_DType[quant_dtype]
-    if tile_size == (1, 128):
+    if tile_size in ((1, 128), (1, 64)):
         block_scaling_dim = 1
-    elif tile_size == (128, 128):
+    elif tile_size in ((128, 128), (64, 64)):
         block_scaling_dim = 2
     else:
         raise ValueError("Non support tile size")

@@ -334,7 +334,7 @@ def check_quantization_block_tiling_versus_reference(
     "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
 )
 @pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
-@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"])
+@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128), (1, 64), (64, 64)], ids=["1D128Tile", "2D128Tile", "1D64Tile", "2D64Tile"])
 def test_quantization_block_tiling_versus_reference(
     x_dtype: torch.dtype,
     M: int,

@@ -345,6 +345,8 @@ def test_quantization_block_tiling_versus_reference(
     pow_2_scales: bool,
     tile_size: Tuple[int, int],
 ) -> None:
+    if blockwise_fp8_block_len != tile_size[1]:
+        pytest.skip("Block len of blockwise is skipped by env.")
     check_quantization_block_tiling_versus_reference(
         x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size
     )

@@ -369,7 +371,7 @@ def test_quantization_block_tiling_versus_reference(
     "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
 )
 @pytest.mark.parametrize("pow_2_scales", [False], ids=["fp32scales"])
-@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"])
+@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128), (1, 64), (64, 64)], ids=["1D128Tile", "2D128Tile", "1D64Tile", "2D64Tile"])
 def test_quantization_block_tiling_versus_reference_fp32_scales(
     x_dtype: torch.dtype,
     M: int,

@@ -380,6 +382,8 @@ def test_quantization_block_tiling_versus_reference_fp32_scales(
     pow_2_scales: bool,
     tile_size: Tuple[int, int],
 ) -> None:
+    if blockwise_fp8_block_len != tile_size[1]:
+        pytest.skip("Block len of blockwise is skipped by env.")
     check_quantization_block_tiling_versus_reference(
         x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size
     )

@@ -397,7 +401,7 @@ def test_quantization_block_tiling_versus_reference_fp32_scales(
 @pytest.mark.parametrize("quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
 @pytest.mark.parametrize("eps", [0], ids=["eps_0"])
 @pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"])
-@pytest.mark.parametrize("tile_size", [(128, 128)])
+@pytest.mark.parametrize("tile_size", [(128, 128), (64, 64)], ids=["2D128Tile", "2D64Tile"])
 @pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
 def test_quantization_block_tiling_extrema_versus_reference(
     x_dtype: torch.dtype,

@@ -411,10 +415,12 @@ def test_quantization_block_tiling_extrema_versus_reference(
 ) -> None:
     # This test runs a single tile through a quantizer as a way to test
     # branch coverage of scale computation.
+    if blockwise_fp8_block_len != tile_size[1]:
+        pytest.skip("Block len of blockwise is skipped by env.")
     te_dtype = TE_DType[quant_dtype]
-    if tile_size == (1, 128):
+    if tile_size in ((1, 128), (1, 64)):
         block_scaling_dim = 1
-    elif tile_size == (128, 128):
+    elif tile_size in ((128, 128), (64, 64)):
         block_scaling_dim = 2
     else:
         raise ValueError("Non support tile size")
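Because the block length is fixed per process by the environment variable, the new parametrizations skip any tile size that does not match it rather than failing. A self-contained sketch of that guard (test name is hypothetical):

import os
import pytest

BLOCK_LEN = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))

@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128), (1, 64), (64, 64)])
def test_matches_env_block_len(tile_size):
    # Only tile sizes consistent with the process-wide block length can run.
    if BLOCK_LEN != tile_size[1]:
        pytest.skip("tile size does not match NVTE_BLOCKWISE_FP8_BLOCK_LEN")
    assert tile_size[1] == BLOCK_LEN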
tests/pytorch/test_int8_blockwise_gemm_exact.py

@@ -4,7 +4,7 @@ import transformer_engine as te
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
     Float8BlockQuantizer,
     Float8BlockwiseQTensor,

@@ -82,8 +82,9 @@ def cublas_gemm_fp8_blockwise_case_fw(
     assert not (use_bias and use_grad), "Bias grad not supported by GEMM"

     # Set quantize_op and quantization parameters
-    x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
-    w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
+    block_len = blockwise_fp8_block_len
+    x_quant_tile_shape = (1, block_len) if is_x_1d_scaled else (block_len, block_len)
+    w_quant_tile_shape = (1, block_len) if is_w_1d_scaled else (block_len, block_len)
     x_block_scaling_dim = 1 if is_x_1d_scaled else 2
     w_block_scaling_dim = 1 if is_w_1d_scaled else 2
     x_te_dtype = TE_DType[x_dtype]

@@ -196,7 +197,7 @@ def cublas_gemm_fp8_blockwise_case_fw(
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
     y, _ = w8a8_block_int8_matmul(
-        qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
+        qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
         output_dtype=out_dtype
     )

@@ -265,8 +266,9 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
     assert not (use_bias and use_grad), "Bias grad not supported by GEMM"

     # Set quantize_op and quantization parameters
-    dout_quant_tile_shape = (1, 128) if is_dout_1d_scaled else (128, 128)
-    w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
+    block_len = blockwise_fp8_block_len
+    dout_quant_tile_shape = (1, block_len) if is_dout_1d_scaled else (block_len, block_len)
+    w_quant_tile_shape = (1, block_len) if is_w_1d_scaled else (block_len, block_len)
     dout_block_scaling_dim = 1 if is_dout_1d_scaled else 2
     w_block_scaling_dim = 1 if is_w_1d_scaled else 2
     dout_te_dtype = TE_DType[dout_dtype]

@@ -373,7 +375,7 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
     y, _ = w8a8_block_int8_matmul(
-        qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
+        qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
         output_dtype=dx_dtype
     )

@@ -441,8 +443,9 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
     assert not (use_bias and use_grad), "Bias grad not supported by GEMM"

     # Set quantize_op and quantization parameters
-    dout_quant_tile_shape = (1, 128) if is_dout_1d_scaled else (128, 128)
-    x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
+    block_len = blockwise_fp8_block_len
+    dout_quant_tile_shape = (1, block_len) if is_dout_1d_scaled else (block_len, block_len)
+    x_quant_tile_shape = (1, block_len) if is_x_1d_scaled else (block_len, block_len)
     dout_block_scaling_dim = 1 if is_dout_1d_scaled else 2
     x_block_scaling_dim = 1 if is_x_1d_scaled else 2
     dout_te_dtype = TE_DType[dout_dtype]

@@ -552,7 +555,8 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
     # print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
     y, _ = w8a8_block_int8_matmul_wgrad(
-        qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
+        qdout_data, qx_data, ref_scales_dout, ref_scales_x,
+        dw.clone() if accumulate else None, accumulate, [block_len, block_len],
         output_dtype=dw_dtype
     )
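The int8 path mirrors the kernels' lroundf(fmaxf(-127.0f, fminf(127.0f, inp))) step: scale in FP32, clamp to [-127, 127], round, then narrow. A hedged torch sketch (note torch's round() uses round-half-to-even, so behavior at exact .5 ties may differ from lroundf):

import torch

def cast_int8(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Scale in FP32, clamp to the symmetric int8 range, round to nearest, narrow.
    scaled = x.float() * scale
    return scaled.clamp(-127.0, 127.0).round().to(torch.int8)

x = torch.randn(64, 64)
q = cast_int8(x, 127.0 / x.abs().amax().item())
assert int(q.abs().max()) <= 127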
transformer_engine/common/common.h

@@ -6,6 +6,7 @@
 #ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
 #define TRANSFORMER_ENGINE_COMMON_COMMON_H_
+#include "util/system.h"
 #ifndef __HIP_PLATFORM_AMD__
 #include <cudaTypedefs.h>
 #endif

@@ -39,6 +40,10 @@ namespace transformer_engine {
 std::string to_string(const DType type);
 std::string to_string(const NVTEScalingMode& mode);

+inline int blockwise_fp8_block_len() {
+  return ::transformer_engine::getenv<int>("NVTE_BLOCKWISE_FP8_BLOCK_LEN", 128);
+}
+
 inline bool is_tensor_scaling(const NVTEScalingMode& mode) {
   return mode == NVTE_DELAYED_TENSOR_SCALING;
 }
transformer_engine/common/recipe/fp8_block_scaling.cu

@@ -14,6 +14,7 @@
 namespace transformer_engine {
 namespace fp8_block_scaling_recipe {

+constexpr int kTileDim64 = 64;
 constexpr int kTileDim = 128;
 constexpr int kThreadsPerBlock = 256;

@@ -116,10 +117,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
     if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset &&
         idx_in_input < end_offset) {
       float inp = static_cast<float>(input_minus_offset[idx_in_input]) * scale;
       if constexpr (std::is_same_v<OType, int8_t>) {
-        smem[h_in_smem][w_in_smem] = static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, inp))));
-      } else {
+        smem[h_in_smem][w_in_smem] =
+            static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, inp))));
+      } else {
         smem[h_in_smem][w_in_smem] = static_cast<OType>(inp);
       }
       skip_store = false;

@@ -175,11 +176,171 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
     }
   }
 }

+template <typename IType>
+__global__ void __launch_bounds__(kThreadsPerBlock)
+    fp8_block_scaling_block_len64_compute_partial_amax_kernel(
+        const IType *input, float *amax_ptr, const size_t amax_stride_h,
+        const size_t amax_stride_w, const size_t h, const size_t w,
+        const size_t start_offset, const size_t len) {
+  constexpr int kThreadsPerWarp = 32;
+  constexpr int kLoopsPerRow = kTileDim64 / kThreadsPerWarp;
+  constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
+  constexpr int kLoopsPerCol = kTileDim64 / kNumWarps;
+
+  const int tile_col = blockIdx.x;
+  const int tile_row = blockIdx.y;
+  const size_t end_offset = start_offset + len;
+  const IType *input_minus_offset = input - start_offset;
+
+  __shared__ float smem[kNumWarps];
+
+  float amax = 0.0f;
+  for (int loop_col = 0; loop_col < kLoopsPerCol; ++loop_col) {
+    size_t r = tile_row * kTileDim64 + loop_col * kNumWarps + threadIdx.x / kThreadsPerWarp;
+    for (int loop_row = 0; loop_row < kLoopsPerRow; ++loop_row) {
+      size_t c = tile_col * kTileDim64 + loop_row * kThreadsPerWarp +
+                 (threadIdx.x % kThreadsPerWarp);
+      size_t idx = r * w + c;
+      if (r < h && c < w && idx >= start_offset && idx < end_offset) {
+        float other_amax = fabs(static_cast<float>(input_minus_offset[idx]));
+        __builtin_assume(amax >= 0);
+        __builtin_assume(other_amax >= 0);
+        amax = fmaxf(amax, other_amax);
+      }
+    }
+  }
+
+  for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
+#ifdef __HIP_PLATFORM_AMD__
+    float other_amax = __shfl_down(amax, delta, kThreadsPerWarp);
+#else
+    float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
+#endif
+    __builtin_assume(amax >= 0);
+    __builtin_assume(other_amax >= 0);
+    amax = fmaxf(amax, other_amax);
+  }
+
+  if (threadIdx.x % kThreadsPerWarp == 0) {
+    smem[threadIdx.x / kThreadsPerWarp] = amax;
+  }
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    for (int i = 0; i < kNumWarps; ++i) {
+      float other_amax = smem[i];
+      __builtin_assume(amax >= 0);
+      __builtin_assume(other_amax >= 0);
+      amax = fmaxf(amax, other_amax);
+    }
+    amax_ptr[tile_row * amax_stride_h + tile_col * amax_stride_w] = amax;
+  }
+}
+
+template <typename IType, typename OType, bool kWidthAligned>
+__global__ void __launch_bounds__(kThreadsPerBlock)
+    fp8_block_scaling_block_len64_partial_cast_kernel(
+        const IType *input, OType *output, const float *scale_ptr,
+        const size_t scale_stride_h, const size_t scale_stride_w, const size_t h,
+        const size_t w, const size_t start_offset, const size_t len) {
+  using transformer_engine::Vec;
+
+  static_assert(sizeof(OType) == 1);
+  constexpr int kNumOutputElemsPerBank = 4 / sizeof(OType);
+  constexpr int kThreadsPerWarp = 32;
+  constexpr int kLoopsPerRow = kTileDim64 / kThreadsPerWarp;
+  constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
+  constexpr int kRowsPerWarp = kTileDim64 / kNumWarps;
+
+  __shared__ OType smem[kTileDim64][kTileDim64 + kNumOutputElemsPerBank];
+
+  const int tile_w = blockIdx.x;
+  const int tile_h = blockIdx.y;
+  const size_t end_offset = start_offset + len;
+  const IType *input_minus_offset = input - start_offset;
+  OType *output_minus_offset = output - start_offset;
+  const float scale = scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w];
+
+  // Load input data into shared memory
+  bool skip_store = true;
+  for (int i = 0; i < kRowsPerWarp; ++i) {
+    for (int j = 0; j < kLoopsPerRow; ++j) {
+      const int h_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
+      const int w_in_smem = threadIdx.x % kThreadsPerWarp + kThreadsPerWarp * j;
+      const int h_in_input = tile_h * kTileDim64 + h_in_smem;
+      const int w_in_input = tile_w * kTileDim64 + w_in_smem;
+      const size_t idx_in_input = static_cast<size_t>(h_in_input) * w + w_in_input;
+      if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset &&
+          idx_in_input < end_offset) {
+        float inp = static_cast<float>(input_minus_offset[idx_in_input]) * scale;
+        if constexpr (std::is_same_v<OType, int8_t>) {
+          smem[h_in_smem][w_in_smem] =
+              static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, inp))));
+        } else {
+          smem[h_in_smem][w_in_smem] = static_cast<OType>(inp);
+        }
+        skip_store = false;
+      }
+    }
+  }
+
+  for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
+#ifdef __HIP_PLATFORM_AMD__
+    bool other_skip_store = __shfl_down(skip_store, delta, kThreadsPerWarp);
+#else
+    bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
+#endif
+    skip_store = skip_store && other_skip_store;
+  }
+#ifdef __HIP_PLATFORM_AMD__
+  skip_store = __shfl(skip_store, 0, kThreadsPerWarp);
+#else
+  skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
+#endif
+  if (skip_store) {
+    return;
+  }
+
+  // Store the casted data into the output.
+  // Note that this store operation might write "out-of-bounds", but it is intentional:
+  // 1. The "out-of-bounds" here only crosses the boundary of the "local shard" (i.e., the region
+  //    from start_offset to end_offset), not the boundary of the entire output memory. Therefore,
+  //    this out-of-bounds write will not cause illegal memory access.
+  // 2. We assume that the subsequent all-gather operation happens in-place, so any parts that
+  //    should not be updated here will be overwritten by the all-gather.
+  // This tricky approach allows us to avoid checking whether each output index falls within
+  // [start, end), resulting in a significant performance improvement.
+  Vec<OType, kNumOutputElemsPerBank> vec_output;
+  for (int i = 0; i < kRowsPerWarp; ++i) {
+    const int row_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
+    const int col_in_smem = threadIdx.x % kThreadsPerWarp * kNumOutputElemsPerBank;
+    for (int j = 0; j < kNumOutputElemsPerBank; ++j) {
+      vec_output.data.elt[j] = smem[row_in_smem][col_in_smem + j];
+    }
+    const int row_in_output = tile_h * kTileDim64 + row_in_smem;
+    const int col_in_output = tile_w * kTileDim64 + col_in_smem;
+    const size_t idx_in_output = static_cast<size_t>(row_in_output) * w + col_in_output;
+    if (row_in_output < h) {
+      if constexpr (kWidthAligned) {
+        vec_output.store_to(output_minus_offset + idx_in_output);
+      } else {
+        int num = min(static_cast<size_t>(kNumOutputElemsPerBank),
+                      static_cast<size_t>(col_in_output < w ? w - col_in_output : 0));
+        vec_output.store_to_elts(output_minus_offset, idx_in_output, num);
+      }
+    }
+  }
+}
+
 void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w,
                                             size_t amax_stride_h, size_t amax_stride_w,
                                             size_t start_offset, size_t block_len,
                                             cudaStream_t stream) {
-  NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
+  NVTE_CHECK(block_len == 128 || block_len == 64, "Currently only block_len = 128 or 64 is supported");
   size_t len = inp.numel();

@@ -187,26 +348,39 @@ void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_
   assert(start_offset < h * w);
   assert(start_offset + len <= h * w);
-  size_t blocks_x = (w + kTileDim - 1) / kTileDim;
-  size_t blocks_y = (h + kTileDim - 1) / kTileDim;
+  size_t blocks_x = (w + block_len - 1) / block_len;
+  size_t blocks_y = (h + block_len - 1) / block_len;
   assert(blocks_x <= std::numeric_limits<unsigned int>::max());
   assert(blocks_y <= std::numeric_limits<unsigned int>::max());
   dim3 grid(blocks_x, blocks_y);
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       inp.dtype(), inp_dtype,
-      fp8_block_scaling_compute_partial_amax_kernel<inp_dtype>
-          <<<grid, kThreadsPerBlock, 0, stream>>>(
-              reinterpret_cast<const inp_dtype *>(inp.data.dptr),
-              reinterpret_cast<float *>(amax.data.dptr),
-              amax_stride_h, amax_stride_w, h, w, start_offset, len);)
+      while (true) {
+        if (128 == block_len) {
+          fp8_block_scaling_compute_partial_amax_kernel<inp_dtype>
+              <<<grid, kThreadsPerBlock, 0, stream>>>(
+                  reinterpret_cast<const inp_dtype *>(inp.data.dptr),
+                  reinterpret_cast<float *>(amax.data.dptr),
+                  amax_stride_h, amax_stride_w, h, w, start_offset, len);
+          break;
+        }
+        if (64 == block_len) {
+          fp8_block_scaling_block_len64_compute_partial_amax_kernel<inp_dtype>
+              <<<grid, kThreadsPerBlock, 0, stream>>>(
+                  reinterpret_cast<const inp_dtype *>(inp.data.dptr),
+                  reinterpret_cast<float *>(amax.data.dptr),
+                  amax_stride_h, amax_stride_w, h, w, start_offset, len);
+          break;
+        }
+      })
 }

 void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h,
                                     size_t w, size_t scale_stride_h, size_t scale_stride_w,
                                     size_t start_offset, size_t block_len, const DType out_dtype,
                                     cudaStream_t stream) {
-  NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
+  NVTE_CHECK(block_len == 128 || block_len == 64, "Currently only block_len = 128 or 64 is supported");
   size_t len = inp.numel();

@@ -214,8 +388,8 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s
   assert(start_offset < h * w);
   assert(start_offset + len <= h * w);
-  size_t blocks_x = (w + kTileDim - 1) / kTileDim;
-  size_t blocks_y = (h + kTileDim - 1) / kTileDim;
+  size_t blocks_x = (w + block_len - 1) / block_len;
+  size_t blocks_y = (h + block_len - 1) / block_len;
   assert(blocks_x <= std::numeric_limits<unsigned int>::max());
   assert(blocks_y <= std::numeric_limits<unsigned int>::max());
   dim3 grid(blocks_x, blocks_y);

@@ -225,13 +399,27 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s
   TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
       out_dtype, fp8_type,
       TRANSFORMER_ENGINE_SWITCH_CONDITION(
-          w % kTileDim == 0, kWidthAligned,
-          fp8_block_scaling_partial_cast_kernel<inp_dtype, fp8_type, kWidthAligned>
-              <<<grid, kThreadsPerBlock, 0, stream>>>(
-                  reinterpret_cast<const inp_dtype *>(inp.data.dptr),
-                  reinterpret_cast<fp8_type *>(out.data.dptr),
-                  reinterpret_cast<const float *>(scale.data.dptr),
-                  scale_stride_h, scale_stride_w, h, w, start_offset, len);)))
+          w % block_len == 0, kWidthAligned,
+          while (true) {
+            if (128 == block_len) {
+              fp8_block_scaling_partial_cast_kernel<inp_dtype, fp8_type, kWidthAligned>
+                  <<<grid, kThreadsPerBlock, 0, stream>>>(
+                      reinterpret_cast<const inp_dtype *>(inp.data.dptr),
+                      reinterpret_cast<fp8_type *>(out.data.dptr),
+                      reinterpret_cast<const float *>(scale.data.dptr),
+                      scale_stride_h, scale_stride_w, h, w, start_offset, len);
+              break;
+            }
+            if (64 == block_len) {
+              fp8_block_scaling_block_len64_partial_cast_kernel<inp_dtype, fp8_type, kWidthAligned>
+                  <<<grid, kThreadsPerBlock, 0, stream>>>(
+                      reinterpret_cast<const inp_dtype *>(inp.data.dptr),
+                      reinterpret_cast<fp8_type *>(out.data.dptr),
+                      reinterpret_cast<const float *>(scale.data.dptr),
+                      scale_stride_h, scale_stride_w, h, w, start_offset, len);
+              break;
+            }
+          })))
 }

 }  // namespace fp8_block_scaling_recipe
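What the new block_len = 64 amax kernel computes is one amax per 64x64 tile on a ceil(h/64) x ceil(w/64) grid. A host-side torch sketch of that result, ignoring the [start_offset, end_offset) shard logic the kernel also applies (tile_amax is a hypothetical reference helper):

import torch

def tile_amax(x: torch.Tensor, tile: int = 64) -> torch.Tensor:
    h, w = x.shape
    gh, gw = -(-h // tile), -(-w // tile)  # ceil-division grid dims
    out = torch.zeros(gh, gw)
    for i in range(gh):
        for j in range(gw):
            out[i, j] = x[i * tile:(i + 1) * tile, j * tile:(j + 1) * tile].abs().amax()
    return out

amax = tile_amax(torch.randn(200, 130))  # 4 x 3 grid of per-tile amaxes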
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu

This diff is collapsed.

transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu

This diff is collapsed.
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -15,6 +15,7 @@ from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ...debug.pytorch.debug_quantization import DebugQuantizer
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len

 int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))

 __all__ = [

@@ -76,7 +77,7 @@ def general_gemm(
             ref_scales_w = A._rowwise_scale_inv
             y, _ = w8a8_block_int8_matmul(
-                qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
+                qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
             )
             return y, None, None, None

@@ -92,7 +93,7 @@ def general_gemm(
             ref_scales_w = A._columnwise_scale_inv
             y, _ = w8a8_block_int8_matmul(
-                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
+                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
             )
             return y, None, None, None

@@ -108,7 +109,7 @@ def general_gemm(
             ref_scales_x = A._columnwise_scale_inv
             out, _ = w8a8_block_int8_matmul_wgrad(
-                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128],
+                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
             )
             return out, None, None, None

@@ -251,7 +252,7 @@ def general_grouped_gemm(
             seq_len = sum(m_splits) // num_gemms
             out[0] = w8a8_block_int8_matmul_batched(
-                qx_data, qw_data, ref_scales_x, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
+                qx_data, qw_data, ref_scales_x, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
             )
             return out, bias, gelu_input

@@ -270,7 +271,7 @@ def general_grouped_gemm(
             seq_len = sum(m_splits) // num_gemms
             out[0] = w8a8_block_int8_matmul_batched(
-                qdout_data, qw_data, ref_scales_dout, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
+                qdout_data, qw_data, ref_scales_dout, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
             )
             return out, bias, gelu_input

@@ -286,7 +287,7 @@ def general_grouped_gemm(
             ref_scales_x = [a._columnwise_scale_inv for a in A]
             out = w8a8_block_int8_matmul_wgrad_batched_native(
-                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128],
+                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
            )
             return out, bias, gelu_input
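Semantically, a block-scaled int8 GEMM dequantizes each operand with its per-block inverse scale and accumulates in FP32. A hedged pure-torch illustration of that semantics (not the repo's Triton kernel; block_int8_matmul_ref is a hypothetical name, NT layout assumed):

import torch

def block_int8_matmul_ref(qx, qw, sx_inv, sw_inv, block):
    bm, bk = block
    # Expand per-block inverse scales to per-element, then dequantize.
    sx = sx_inv.repeat_interleave(bm, 0).repeat_interleave(bk, 1)[:qx.shape[0], :qx.shape[1]]
    sw = sw_inv.repeat_interleave(bm, 0).repeat_interleave(bk, 1)[:qw.shape[0], :qw.shape[1]]
    x = qx.float() * sx
    w = qw.float() * sw
    return x @ w.t()  # NT layout: w is (N, K)

qx = torch.randint(-127, 128, (128, 256), dtype=torch.int8)
qw = torch.randint(-127, 128, (64, 256), dtype=torch.int8)
sx = torch.rand(1, 2)  # one scale per (128, 128) block of qx
sw = torch.rand(1, 2)
y = block_int8_matmul_ref(qx, qw, sx, sw, (128, 128))  # shape (128, 64)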
transformer_engine/pytorch/csrc/common.h

@@ -49,6 +49,8 @@
 #include <cassert>
 #include <cstring>
 #include <iostream>
+#include <string>
+#include <sstream>
 #include <memory>
 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
 #include <vector>

@@ -61,6 +63,18 @@ namespace transformer_engine::pytorch {
 // in python we have: dist_group_type = torch.distributed.ProcessGroup
 using dist_group_type = c10d::ProcessGroup;

+inline int blockwise_fp8_block_len() {
+  const char* env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN");
+  if (env == nullptr || env[0] == '\0') {
+    return 128;
+  }
+  int value;
+  std::istringstream iss(env);
+  iss >> value;
+  NVTE_CHECK(iss, "Invalid environment variable value");
+  return value;
+}
+
 // Each tensor here is shape (N, ) holding all scaling
 // data for a single FP8 block, e.g. LayerNormLinear
 class FP8TensorMeta {
transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp

@@ -10,7 +10,7 @@ namespace transformer_engine::pytorch {
 void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
                                             size_t w, size_t start_offset, size_t block_len) {
-  TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
+  TORCH_CHECK(block_len == 128 || block_len == 64, "Currently only block_len = 128 or 64 is supported");
   TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor");
   TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor");
   TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float ||

@@ -28,7 +28,7 @@ void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor
 void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
                                     size_t h, size_t w, size_t start_offset, size_t block_len,
                                     const transformer_engine::DType out_dtype) {
-  TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
+  TORCH_CHECK(block_len == 128 || block_len == 64, "Currently only block_len = 128 or 64 is supported");
   TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
   TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
   TORCH_CHECK(
transformer_engine/pytorch/csrc/quantizer.cpp

@@ -298,7 +298,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
   size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
   size_t m_dim = numel / k_dim;
-  constexpr size_t kBlockLen = 128;
+  size_t kBlockLen = static_cast<size_t>(blockwise_fp8_block_len());
   Float8BlockScaleTensorFormat data_format =
       (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
transformer_engine/pytorch/distributed.py

@@ -1084,7 +1084,7 @@ def _all_gather_fp8_blockwise(
     # Check that quantizer is valid
     if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer):
         raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
-    if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128):
+    if not (quantizer.block_scaling_dim == 1 and (quantizer.block_len == 128 or quantizer.block_len == 64)):
         raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")

     # Output tensor dims
transformer_engine/pytorch/fp8.py

@@ -28,6 +28,7 @@ from .utils import get_device_compute_capability
 from .jit import jit_fuser
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

 int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
+blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))

 __all__ = ["fp8_autocast", "fp8_model_init"]
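Because blockwise_fp8_block_len is evaluated once at module import, the environment variable must be set before transformer_engine.pytorch is first imported. A usage sketch (assumes transformer_engine is installed):

import os

os.environ["NVTE_BLOCKWISE_FP8_BLOCK_LEN"] = "64"  # must precede the import below
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len

assert blockwise_fp8_block_len == 64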
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py

@@ -12,6 +12,7 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
 from transformer_engine_torch import Float8BlockScaleTensorFormat
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len

 from ..quantized_tensor import QuantizedTensorBase

@@ -134,7 +135,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous()

     def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
-        block_len = 128
+        block_len = blockwise_fp8_block_len
         q_M, q_K = 1, 1

         if self._rowwise_data is not None:

@@ -222,7 +223,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         """
         Construct plain PyTorch tensor from Float8BlockwiseQTensor
         """
-        block_len = 128
+        block_len = blockwise_fp8_block_len
         if not self._is_2D_scaled:
             return self._dequantize_vectorwise(dtype=dtype)
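For vector-wise (1 x block_len) tiles, dequantization expands one scale_inv per block_len-wide slice of K back to per-element scales. A hedged illustration of that expansion (not the repo code; dequantize_vectorwise_ref is a hypothetical name):

import torch

def dequantize_vectorwise_ref(qdata: torch.Tensor, scale_inv: torch.Tensor,
                              block_len: int) -> torch.Tensor:
    M, K = qdata.shape
    # Each row holds ceil(K / block_len) inverse scales; expand along K and trim.
    expanded = scale_inv.repeat_interleave(block_len, dim=1)[:, :K]
    return qdata.float() * expanded

q = torch.randint(-127, 128, (4, 100), dtype=torch.int8)
s = torch.rand(4, 2)  # ceil(100 / 64) = 2 blocks per row with block_len = 64
x = dequantize_vectorwise_ref(q, s, 64)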
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py

@@ -16,6 +16,7 @@ from transformer_engine.common.recipe import Float8BlockScaling, Recipe
 from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
 from ..utils import devices_match, round_up_to_nearest_multiple
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len

 aten = torch.ops.aten

@@ -51,7 +52,7 @@ class Float8BlockQuantizer(Quantizer):
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
         self.dtype = tex.DType.kInt8 if int8_simulation_fp8 else fp8_dtype
-        self.block_len = 128
+        self.block_len = blockwise_fp8_block_len
         self.force_pow_2_scales = force_pow_2_scales
         self.amax_epsilon = amax_epsilon
         self.block_scaling_dim = block_scaling_dim
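After this change, Float8BlockQuantizer derives block_len from the env-driven module global rather than hardcoding 128. A construction sketch, with kwargs inferred from the parameters visible in this diff rather than confirmed against the full signature (assumes transformer_engine is installed):

import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

quantizer = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    rowwise=True,
    columnwise=True,
    block_scaling_dim=2,  # (block_len, block_len) square tiles
)
print(quantizer.block_len)  # 128 by default; 64 when NVTE_BLOCKWISE_FP8_BLOCK_LEN=64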