OpenDAS / TransformerEngine

Commit a5892578, authored Jul 15, 2025 by yuguo

Merge branch 'develop_v2.4' of http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine

Parents: f9faa7ca, 793e0103
Showing 8 changed files with 294 additions and 6 deletions.
- qa/L0_pytorch_unittest/test.sh (+2 -0)
- tests/pytorch/test_float8_current_scaling_exact.py (+3 -2)
- tests/pytorch/test_int8_channelwise_gemm_exact.py (+1 -1)
- transformer_engine/common/include/transformer_engine/gemm.h (+2 -2)
- transformer_engine/common/transformer_engine.cpp (+4 -0)
- transformer_engine/pytorch/cpp_extensions/gemm.py (+270 -1)
- transformer_engine/pytorch/triton/per_token_group_quant.py (+10 -0)
- transformer_engine/pytorch/utils.py (+2 -0)
qa/L0_pytorch_unittest/test.sh

```diff
@@ -34,6 +34,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tes
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
+# channelwise int8 test
+NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
 python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
```
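The two added lines run the FP8 current-scaling tests with channelwise int8 simulation enabled via NVTE_INT8_SIM_FP8. Outside the CI script, a minimal sketch of flipping the same switch from Python (assuming, as in cpp_extensions/gemm.py below, that the flag is read once at import time):

```python
import os

# Set the flag before importing transformer_engine.pytorch, since it is read at module import
os.environ["NVTE_INT8_SIM_FP8"] = "1"

import transformer_engine.pytorch  # noqa: F401  # int8-simulation GEMM paths now take effect
```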
tests/pytorch/test_float8_current_scaling_exact.py

```diff
@@ -14,6 +14,7 @@ import transformer_engine_torch as tex
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.common.recipe import Float8CurrentScaling
 from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
+from transformer_engine.pytorch.fp8 import int8_simulation_fp8

 # read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory
@@ -715,7 +716,7 @@ class TestFP8CurrentScalingRecipeLinear(TestFP8RecipeLinearBase):
         hidden_size,
         out_size,
         dtype,
-        use_bias=True,
+        use_bias=False if int8_simulation_fp8 else True,
     ):
         fp8_zero_tolerance_tensor_dumps_recipe2 = None
         # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
@@ -775,7 +776,7 @@ class TestFP8CurrentScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBas
         hidden_size,
         out_size,
         dtype,
-        use_bias=True,
+        use_bias=False if int8_simulation_fp8 else True,
     ):
         fp8_zero_tolerance_tensor_dumps_recipe2 = None
         # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
```
tests/pytorch/test_int8_channelwise_gemm_exact.py

```diff
@@ -557,4 +557,4 @@ te_dw = tex.generic_batchgemm(
 )[0]
 # print("te_dw.shape: ", te_dw.view(b, -1, te_dw.size(-1)).shape)
 # print("te_dw: ", te_dw.view(b, -1, te_dw.size(-1)))
-torch.testing.assert_close(te_dw, batched_int32_dw, atol=1e-5, rtol=1e-5)
+torch.testing.assert_close(te_dw.view(b, -1, te_dw.size(-1)), batched_int32_dw, atol=0, rtol=0)
```
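The comparison is tightened from loose tolerances to an exact match: both sides are int8 GEMMs accumulated in int32, so any correct implementation is bit-identical to the reference. A small self-contained sketch (synthetic data, not taken from the test) of why atol=0/rtol=0 is legitimate here:

```python
import torch

# Synthetic int8 operands; int32 accumulation is exact, so two correct
# implementations of the same GEMM agree bit-for-bit.
a = torch.randint(-128, 128, (64, 32), dtype=torch.int8)
b = torch.randint(-128, 128, (32, 16), dtype=torch.int8)

fast = torch.matmul(a.to(torch.int32), b.to(torch.int32))  # library matmul with int32 accumulators

naive = torch.zeros(64, 16, dtype=torch.int32)             # rank-1 update reference
for k in range(32):
    naive += a[:, k : k + 1].to(torch.int32) * b[k : k + 1, :].to(torch.int32)

torch.testing.assert_close(fast, naive, atol=0, rtol=0)     # exact match, hence the zero tolerances above
```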
transformer_engine/common/include/transformer_engine/gemm.h

```diff
@@ -42,7 +42,7 @@ extern "C" {
 void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                       NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                       NVTETensor workspace, bool accumulate, bool use_split_accumulator,
-                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1);
+                      int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = 0);

 /*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
  *
@@ -77,7 +77,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
                              bool use_split_accumulator, int math_sm_count, int m_split,
                              int n_split, bool gemm_producer, const NVTETensor counter,
-                             cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1);
+                             cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = 0);

 /*! \brief Compute multiple pairs of matrix multiplication, potentially fused with other operations,
  * on multiple streams.
```
transformer_engine/common/transformer_engine.cpp

```diff
@@ -638,6 +638,9 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
 }

 int nvte_is_non_tn_fp8_gemm_supported() {
+#if USE_ROCM
+  return true;
+#else
   int deviceComputeCapability =
       transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
@@ -645,4 +648,5 @@ int nvte_is_non_tn_fp8_gemm_supported() {
   // (remove the note once it's done.)
   return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
          deviceComputeCapability >= 130;
+#endif
 }
```
transformer_engine/pytorch/cpp_extensions/gemm.py

```diff
@@ -9,14 +9,24 @@ import os
 import torch
 import transformer_engine_torch as tex
 import w8a8_matmul_extension
-from ..constants import TE_DType
+from ..constants import TE_DType, TE_DType_To_Torch
 from ..utils import get_sm_count, _empty_tensor
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched, w8a8_block_int8_matmul_wgrad_batched_native
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ...debug.pytorch.debug_quantization import DebugQuantizer
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
+from transformer_engine.pytorch.triton.per_token_group_quant import (
+    per_token_quant_fp8_to_int8, per_token_quant_fp8_to_int8_v2, per_token_quant_fp8_to_int8_opt,
+    channelwise_dequantize, channelwise_dequantize_transA, channelwise_dequantize_transA_float,
+    channelwise_dequantize_transB, channelwise_dequantize_transA_add, channelwise_dequantize_transA_float_add,
+)
+
+int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))

 __all__ = [
@@ -165,6 +175,106 @@ def general_gemm(
     ):
         raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")

+    if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)):
+        assert not gelu, "GELU not supported with int8 simulation"
+        assert gelu_in is None, "GELU input not supported with int8 simulation"
+        assert bias is None, "Bias not supported with int8 simulation"
+        assert ub is None, "User buffer not supported with int8 simulation"
+        assert ub_type is None, "User buffer type not supported with int8 simulation"
+        assert extra_output is None, "Extra output not supported with int8 simulation"
+        assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
+        assert (
+            out_dtype is torch.bfloat16 or out_dtype is torch.float32
+        ), "Out_dtype must be bfloat16 or float32 for int8 simulation"
+        if layout == "TN":
+            assert out_dtype is torch.bfloat16
+            x_int8, x_scales = per_token_quant_fp8_to_int8(
+                B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False
+            )
+            w_int8, w_scales = per_token_quant_fp8_to_int8(
+                A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False
+            )
+            y_int32 = tex.generic_gemm(
+                w_int8, transa, x_int8, transb, None, quantization_params,
+                TE_DType[torch.int32], bias, bias_dtype, gelu, gelu_in,
+                grad,  # grad
+                workspace, workspace.shape[0], False, use_split_accumulator,
+            )[0]
+            y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
+            return y, None, None, None
+        elif layout == "NN":
+            assert out_dtype is torch.bfloat16
+            dy_int8, dy_scales = per_token_quant_fp8_to_int8(
+                B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False
+            )
+            w_int8, w_scales = per_token_quant_fp8_to_int8_opt(
+                A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False
+            )
+            dx_int32 = tex.generic_gemm(
+                w_int8, transa, dy_int8, transb, None, quantization_params,
+                TE_DType[torch.int32], bias, bias_dtype, gelu, gelu_in,
+                grad,  # grad
+                workspace, workspace.shape[0], False, use_split_accumulator,
+            )[0]
+            dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
+            return dx, None, None, None
+        elif layout == "NT":
+            dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(
+                B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False
+            )
+            x_int8, x_scales = per_token_quant_fp8_to_int8_opt(
+                A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False
+            )
+            dw_int32 = tex.generic_gemm(
+                x_int8, transa, dy_int8, transb, None, quantization_params,
+                TE_DType[torch.int32], bias, bias_dtype, gelu, gelu_in,
+                grad,  # grad
+                workspace, workspace.shape[0], False, use_split_accumulator,
+            )[0]
+            if out_dtype is torch.bfloat16:
+                if accumulate:
+                    out = channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
+                else:
+                    out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
+            else:
+                if accumulate:
+                    out = channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
+                else:
+                    out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
+            return out, None, None, None
+        else:
+            raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
+
     args = (
         A,
         transa,  # transa
```
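In the TN (forward) path above, both operands are reinterpreted from their FP8 payloads, re-quantized per token to int8, multiplied with int32 accumulation, and then rescaled by the outer product of the two per-row scale vectors. A rough, self-contained reference of that quantize/GEMM/dequantize round trip (plain torch ops standing in for the triton kernels; shapes and names are illustrative only):

```python
import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    """Illustrative per-token (per-row) symmetric int8 quantization; not the triton kernel."""
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0  # one scale per row
    q = torch.clamp((x / scales).round(), -127, 127).to(torch.int8)
    return q, scales

x = torch.randn(8, 16)   # activations: tokens x in_features
w = torch.randn(4, 16)   # weights:     out_features x in_features
qx, sx = per_token_quant_int8_ref(x)
qw, sw = per_token_quant_int8_ref(w)

# int8 GEMM with int32 accumulation (the role of tex.generic_gemm producing y_int32)
y_int32 = torch.matmul(qx.to(torch.int32), qw.to(torch.int32).T)

# channelwise dequantization: the outer product of the two scale vectors rescales every element
y = (sx * sw.T) * y_int32.to(torch.float32)

print((y - x @ w.T).abs().max())  # small quantization error vs. the fp32 reference
```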
```diff
@@ -311,6 +421,165 @@ def general_grouped_gemm(
         else:
             raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")

+    if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)):
+        assert (
+            len(set(m_splits)) == 1
+        ), "Int8 simulation groupgemm only supports token pad same as batchgemm for now."
+        assert not gelu, "GELU not supported with int8 simulation groupgemm."
+        assert not use_bias, "Bias not supported with int8 simulation groupgemm."
+        assert (
+            out_dtype is torch.bfloat16 or out_dtype is torch.float32
+        ), "Out_dtype must be bfloat16 or float32 for int8 simulation"
+        if layout == "TN":
+            assert out_dtype is torch.bfloat16
+            qx_data_list = []
+            w_data_list = []
+            scales_x_list = []
+            scales_w_list = []
+            for b in B:
+                x_int8, x_scales = per_token_quant_fp8_to_int8(
+                    b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]), b._scale_inv, False
+                )
+                qx_data_list.append(x_int8)
+                scales_x_list.append(x_scales)
+            for a in A:
+                w_int8, w_scales = per_token_quant_fp8_to_int8(
+                    a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]), a._scale_inv, False
+                )
+                w_data_list.append(w_int8)
+                scales_w_list.append(w_scales)
+            num_gemms = len(A)
+            seq_len = sum(m_splits) // num_gemms
+            qx_data = torch.stack(qx_data_list).contiguous()
+            w_data = torch.stack(w_data_list).contiguous()
+            y_int32 = torch.empty(
+                (num_gemms, seq_len, out[0].size(-1)), dtype=torch.int32, device="cuda"
+            )
+            y_int32 = tex.generic_batchgemm(
+                w_data.view(-1, w_data.size(-1)), transa,
+                qx_data.view(-1, qx_data.size(-1)), transb,
+                y_int32.view(-1, y_int32.size(-1)), num_gemms,
+                None, TE_DType[torch.int32], bias[0], bias_dtype, gelu, gelu_input[0],
+                grad,  # grad
+                workspaces[0], workspaces[0].shape[0], False, use_split_accumulator,
+            )[0]
+            out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
+            for i in range(num_gemms):
+                out[0][i] = channelwise_dequantize_transB(scales_x_list[i], scales_w_list[i], y_int32[i])
+            return out.view(-1, out[0].size(-1)), bias, gelu_input
+        elif layout == "NN":
+            assert out_dtype is torch.bfloat16
+            qdout_data_list = []
+            w_data_list = []
+            scales_dout_list = []
+            scales_w_list = []
+            for b in B:
+                dy_int8, dy_scales = per_token_quant_fp8_to_int8(
+                    b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]), b._scale_inv, False
+                )
+                qdout_data_list.append(dy_int8)
+                scales_dout_list.append(dy_scales)
+            for a in A:
+                w_int8, w_scales = per_token_quant_fp8_to_int8_opt(
+                    a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]), a._scale_inv, False
+                )
+                w_data_list.append(w_int8)
+                scales_w_list.append(w_scales)
+            num_gemms = len(A)
+            seq_len = sum(m_splits) // num_gemms
+            qdout_data = torch.stack(qdout_data_list).contiguous()
+            w_data = torch.stack(w_data_list).contiguous()
+            dx_int32 = torch.empty(
+                (num_gemms, seq_len, out[0].size(-1)), dtype=torch.int32, device="cuda"
+            )
+            dx_int32 = tex.generic_batchgemm(
+                w_data.view(-1, w_data.size(-1)), transa,
+                qdout_data.view(-1, qdout_data.size(-1)), transb,
+                dx_int32.view(-1, dx_int32.size(-1)), num_gemms,
+                None, TE_DType[torch.int32], bias[0], bias_dtype, gelu, gelu_input[0],
+                grad,  # grad
+                workspaces[0], workspaces[0].shape[0], False, use_split_accumulator,
+            )[0]
+            out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
+            for i in range(num_gemms):
+                out[0][i] = channelwise_dequantize(scales_dout_list[i], scales_w_list[i], dx_int32[i])
+            return out, bias, gelu_input
+        elif layout == "NT":
+            qdout_data_list = []
+            qx_data_list = []
+            scales_dout_list = []
+            scales_x_list = []
+            for b in B:
+                dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(
+                    b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]), b._scale_inv, False
+                )
+                qdout_data_list.append(dy_int8)
+                scales_dout_list.append(dy_scales)
+            for a in A:
+                x_int8, x_scales = per_token_quant_fp8_to_int8_opt(
+                    a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]), a._scale_inv, False
+                )
+                qx_data_list.append(x_int8)
+                scales_x_list.append(x_scales)
+            num_gemms = len(A)
+            qdout_data = torch.stack(qdout_data_list).contiguous()
+            qx_data = torch.stack(qx_data_list).contiguous()
+            dw_int32 = torch.empty(
+                (num_gemms, qdout_data.size(-1), qx_data.size(-1)), dtype=torch.int32, device="cuda"
+            )
+            dw_int32 = tex.generic_batchgemm(
+                qx_data.view(-1, qx_data.size(-1)), transa,
+                qdout_data.view(-1, qdout_data.size(-1)), transb,
+                dw_int32.view(-1, dw_int32.size(-1)), num_gemms,
+                None, TE_DType[torch.int32], bias[0], bias_dtype, gelu, gelu_input[0],
+                grad,  # grad
+                workspaces[0], workspaces[0].shape[0], False, use_split_accumulator,
+            )[0]
+            if out_dtype is torch.bfloat16:
+                if accumulate:
+                    for i in range(num_gemms):
+                        out[i] = channelwise_dequantize_transA_add(
+                            scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i]
+                        )
+                else:
+                    for i in range(num_gemms):
+                        out[i] = channelwise_dequantize_transA(
+                            scales_dout_list[i], scales_x_list[i], dw_int32[i]
+                        )
+            else:
+                if accumulate:
+                    for i in range(num_gemms):
+                        out[i] = channelwise_dequantize_transA_float_add(
+                            scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i]
+                        )
+                else:
+                    for i in range(num_gemms):
+                        out[i] = channelwise_dequantize_transA_float(
+                            scales_dout_list[i], scales_x_list[i], dw_int32[i]
+                        )
+            return out, bias, gelu_input
+        else:
+            raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
+
     bias = tex.te_general_grouped_gemm(
         A,
         transa,
```
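The grouped path follows the same recipe per expert: each pair of operands is quantized per token, the int8 matrices are stacked and run through one batched int32 GEMM, and each expert's slice is dequantized with its own pair of scale vectors. A small sketch of that stacking and per-expert dequantization (synthetic scales and shapes; plain torch loops in place of tex.generic_batchgemm):

```python
import torch

num_gemms, seq_len, hidden, out_features = 4, 32, 16, 8

# Per-expert int8 operands, stand-ins for the stacked qx_data / w_data above
qx = torch.randint(-127, 128, (num_gemms, seq_len, hidden), dtype=torch.int8)
qw = torch.randint(-127, 128, (num_gemms, out_features, hidden), dtype=torch.int8)

# One int32 GEMM per expert (the batched kernel fuses these into a single launch)
y_int32 = torch.stack(
    [torch.matmul(qx[i].to(torch.int32), qw[i].to(torch.int32).T) for i in range(num_gemms)]
)

# Per-expert dequantization with that expert's per-token and per-channel scales
sx = torch.rand(num_gemms, seq_len, 1)        # per-token scales
sw = torch.rand(num_gemms, 1, out_features)   # per-output-channel scales
y = (sx * sw) * y_int32.to(torch.float32)     # (num_gemms, seq_len, out_features)
```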
transformer_engine/pytorch/triton/per_token_group_quant.py

```diff
@@ -328,6 +328,16 @@ def channelwise_dequantize_transA_float(A, B, C):
     out_scales = A.T * B
     return out_scales * C.to(dtype=torch.float32)

+
+@torch.compile(mode="max-autotune-no-cudagraphs")
+def channelwise_dequantize_transA_add(A, B, C, D):
+    out_scales = A.T * B
+    return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16) + D
+
+
+@torch.compile(mode="max-autotune-no-cudagraphs")
+def channelwise_dequantize_transA_float_add(A, B, C, D):
+    out_scales = A.T * B
+    return out_scales * C.to(dtype=torch.float32) + D
+
+
 @torch.compile(mode="max-autotune-no-cudagraphs")
 def channelwise_dequantize_transB(A, B, C):
     out_scales = A * B.T
```
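These helpers all compute the same outer-product rescaling; the `_add` variants additionally accumulate into an existing tensor. A tiny shape check (the shapes are assumptions, chosen so the scale vectors broadcast the way the callers in gemm.py expect):

```python
import torch

A = torch.rand(1, 6)      # scales for the rows of C, hence the transpose A.T
B = torch.rand(1, 4)      # scales for the columns of C
C = torch.randint(-1000, 1000, (6, 4), dtype=torch.int32)   # raw int32 GEMM output
D = torch.randn(6, 4, dtype=torch.bfloat16)                 # existing accumulator

out_scales = A.T * B                                # (6, 4) outer product of the scale vectors
dequant = out_scales * C.to(torch.float32)          # what channelwise_dequantize_transA_float returns
accumulated = dequant.to(torch.bfloat16) + D        # what channelwise_dequantize_transA_add returns
```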
transformer_engine/pytorch/utils.py

```diff
@@ -475,6 +475,8 @@ def is_non_tn_fp8_gemm_supported() -> bool:
     """Checks whether the device supports
     non-TN layouts for FP8 GEMMs.
     """
+    if IS_HIP_EXTENSION:
+        return True
     device_capability = torch.cuda.get_device_capability()
     return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
```
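This Python helper mirrors the C++ change: HIP/ROCm builds short-circuit to True, otherwise a tuple comparison gates on compute capability. A quick sanity check of the gate (the capabilities listed are examples only):

```python
# Example capabilities run through the same tuple comparison as above
for cap in [(8, 0), (9, 0), (10, 0), (11, 8), (12, 0), (13, 0)]:
    supported = (10, 0) <= cap < (12, 0) or cap >= (13, 0)
    print(cap, supported)
# Only (10, 0), (11, 8) and (13, 0) pass; HIP/ROCm builds return True unconditionally.
```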