OpenDAS / TransformerEngine

Commit e45d66a3, authored Sep 18, 2025 by yuguo
Merge branch 'release_v2.7' of https://github.com/NVIDIA/TransformerEngine into release_v2.7
Parents: 419897d1, fedd9ddc

Showing 20 changed files with 1336 additions and 43 deletions (+1336, -43)
qa/L0_cppunittest/test.sh                                             +1    -1
qa/L0_pytorch_unittest/test.sh                                        +0    -2
qa/L1_cpp_distributed/test.sh                                         +15   -0
qa/L1_pytorch_onnx_unittest/test.sh                                   +11   -0
setup.py                                                              +13   -0
tests/cpp/CMakeLists.txt                                              +1    -0
tests/cpp/comm_gemm/CMakeLists.txt                                    +20   -0
tests/cpp/comm_gemm/test_comm_gemm.cu                                 +441  -0
tests/pytorch/test_numerics.py                                        +13   -4
tests/pytorch/test_onnx_export.py                                     +58   -1
tests/pytorch/utils.py                                                +1    -1
transformer_engine/common/CMakeLists.txt                              +23   -5
transformer_engine/common/comm_gemm/comm_gemm.cpp                     +519  -0
transformer_engine/common/common.cu                                   +18   -0
transformer_engine/common/common.h                                    +14   -2
transformer_engine/common/gemm/cublaslt_gemm.cu                       +0    -18
transformer_engine/common/include/transformer_engine/comm_gemm.h      +156  -0
transformer_engine/common/util/logging.h                              +17   -0
transformer_engine/pytorch/attention/dot_product_attention/utils.py   +11   -5
transformer_engine/pytorch/onnx_extensions.py                         +4    -4
qa/L0_cppunittest/test.sh

@@ -17,4 +17,4 @@ cd $TE_PATH/tests/cpp
 cmake -GNinja -Bbuild .
 cmake --build build
 export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS))
-ctest --test-dir build -j $NUM_PARALLEL_JOBS
+ctest --test-dir build -j $NUM_PARALLEL_JOBS -E '(AgGemm|GemmRs|GemmAr)'
qa/L0_pytorch_unittest/test.sh

@@ -23,8 +23,6 @@ set -x
 mkdir -p "$XML_LOG_DIR"
 pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
-pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime"
-pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
qa/L1_cpp_distributed/test.sh (new file, mode 100755)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -e

# Find TE
: ${TE_PATH:=/opt/transformerengine}
TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH

cd $TE_PATH/tests/cpp
cmake -GNinja -S . -Bbuild
cmake --build build
mpirun --allow-run-as-root --np 4 --oversubscribe ./build/comm_gemm/test_comm_gemm
qa/L1_pytorch_onnx_unittest/test.sh (new file, mode 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

pip3 install onnxruntime==1.20.1
pip3 install onnxruntime_extensions==0.13.0

: ${TE_PATH:=/opt/transformerengine}

python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py
setup.py

@@ -6,6 +6,7 @@
 # NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
 # NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel

+from importlib import metadata
 import os
 import time
 from pathlib import Path

@@ -82,6 +83,18 @@ def setup_common_extension() -> CMakeExtension:
     if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
         cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

+    if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))):
+        cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
+        cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
+            "nvidia-cublasmp-cu12"
+        ).locate_file("nvidia/cublasmp/cu12")
+        cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
+        nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
+            "nvidia-nvshmem-cu12"
+        ).locate_file("nvidia/nvshmem")
+        cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
+        print("CMAKE_FLAGS:", cmake_flags[-2:])

     # Add custom CMake arguments from environment variable
     nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
     if nvte_cmake_extra_args:
tests/cpp/CMakeLists.txt

@@ -77,6 +77,7 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_
 message(STATUS "Found transformer_engine library: ${TE_LIB}")
 include_directories(../../transformer_engine/common/include)
 include_directories(../../transformer_engine/common)
 include_directories(../../transformer_engine)
 include_directories(${CMAKE_SOURCE_DIR})
 if(USE_CUDA)
tests/cpp/comm_gemm/CMakeLists.txt (new file, mode 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

if(USE_CUDA)
  add_executable(test_comm_gemm
                 test_comm_gemm.cu
                 ../test_common.cu)

  find_package(OpenMP REQUIRED)
  find_package(MPI REQUIRED)
  find_library(NCCL_LIB
               NAMES nccl libnccl
               PATH_SUFFIXES lib
               REQUIRED)
  target_include_directories(test_comm_gemm PRIVATE
                             ${MPI_CXX_INCLUDE_PATH}
                             $ENV{CUBLASMP_HOME}/include)
  target_link_libraries(test_comm_gemm PUBLIC
                        CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB}
                        CUDA::nvrtc CUDNN::cudnn MPI::MPI_CXX ${NCCL_LIB}
                        OpenMP::OpenMP_CXX)

  include(GoogleTest)
  gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600)
endif()
tests/cpp/comm_gemm/test_comm_gemm.cu (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cuda.h>
#include <gtest/gtest.h>
#include <mpi.h>
#include <nccl.h>
#include <transformer_engine/comm_gemm.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

#include <algorithm>
#include <iostream>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <vector>

#include "../test_common.h"
#include "common.h"

using transformer_engine::DType;
using transformer_engine::TypeInfo;

#define CHECK_MPI(expr)                                              \
  do {                                                               \
    int err = (expr);                                                \
    if (err != MPI_SUCCESS) {                                        \
      char err_str[MPI_MAX_ERROR_STRING + 1]{};                      \
      int _len{};                                                    \
      MPI_Error_string(err, err_str, &_len);                         \
      EXPECT_TRUE(false) << "MPI error: " << err << ": " << err_str; \
    }                                                                \
  } while (false)

#define CHECK_NCCL(expr)                                                              \
  do {                                                                                \
    ncclResult_t err = (expr);                                                        \
    if (err != ncclSuccess) {                                                         \
      EXPECT_TRUE(false) << "NCCL error: " << err << ": " << ncclGetErrorString(err); \
    }                                                                                 \
  } while (false)

#define CHECK_CU(expr)                                          \
  do {                                                          \
    CUresult err = (expr);                                      \
    if (err != CUDA_SUCCESS) {                                  \
      const char* str{};                                        \
      CUresult e_str = cuGetErrorString(err, &str);             \
      if (e_str != CUDA_SUCCESS) str = "(unknown)";             \
      EXPECT_TRUE(false) << "CU error: " << err << ": " << str; \
    }                                                           \
  } while (false)

int main(int argc, char* argv[]) {
  ::testing::InitGoogleTest(&argc, argv);
  CHECK_MPI(MPI_Init(&argc, &argv));
  auto ret = RUN_ALL_TESTS();
  CHECK_MPI(MPI_Finalize());
  return ret;
}

bool IsMulticastSupported(int device_id) {
  int supported = 0;
  CHECK_CU(cuDeviceGetAttribute(&supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, device_id));
  return supported;
}

template <typename T>
std::vector<T> CopyMatrix(const std::vector<T>& data, size_t mstart, size_t nstart, size_t msize,
                          size_t nsize, size_t ld) {
  std::vector<T> ret(msize * nsize);
  size_t dst = 0;
  for (size_t j = nstart; j < nstart + nsize; ++j) {
    for (size_t i = mstart; i < mstart + msize; ++i) {
      ret[dst++] = data[j * ld + i];
    }
  }
  return ret;
}

template <typename T>
test::Tensor Make(size_t m, size_t n, float scale) {
  test::Tensor ret("", std::vector{n, m}, TypeInfo<T>::dtype);
  ret.set_scale(scale);
  ret.set_scale_inv(1.0 / scale);
  return ret;
}

template <typename T>
test::Tensor MakeFromData(const std::vector<T>& data, size_t mstart, size_t nstart, size_t msize,
                          size_t nsize, size_t ld, float scale) {
  test::Tensor ret("", std::vector{nsize, msize}, TypeInfo<T>::dtype);
  ret.set_scale(scale);
  ret.set_scale_inv(1.0 / scale);
  auto local = CopyMatrix(data, mstart, nstart, msize, nsize, ld);
  NVTE_CHECK_CUDA(cudaMemcpy(ret.rowwise_dptr(), local.data(), local.size() * sizeof local[0],
                             cudaMemcpyDefault));
  return ret;
}

template <typename T>
float GetScale(float amax) {
  if constexpr (sizeof(T) > 1) return 1.0;
  return static_cast<float>(static_cast<T>(std::numeric_limits<float>::max())) / amax;
}

struct Params {
  DType a_type;
  DType b_type;
  DType d_type;
  bool transa;
  bool transb;
  size_t m;
  size_t n;
  size_t k;
  float tol;
};

class CommGemmFixure : public ::testing::TestWithParam<Params> {
 protected:
  CommGemmFixure() {
    CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &nranks_));
    CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &rank_));
    NVTE_CHECK_CUDA(cudaSetDevice(rank_));
    ncclUniqueId id{};
    if (rank_ == 0) CHECK_NCCL(ncclGetUniqueId(&id));
    CHECK_MPI(MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
    CHECK_NCCL(ncclCommInitRank(&comm_, nranks_, id, rank_));
    ctx_ = nvte_comm_gemm_ctx_create(comm_, nranks_, rank_);
  }

  ~CommGemmFixure() {
    nvte_comm_gemm_ctx_destroy(ctx_);
    ncclCommDestroy(comm_);
  }

  struct PatternDims {
    int64_t a_rows_start;
    int64_t a_rows_num;
    int64_t a_cols_start;
    int64_t a_cols_num;
    int64_t b_rows_start;
    int64_t b_rows_num;
    int64_t b_cols_start;
    int64_t b_cols_num;
    int64_t d_rows_start;
    int64_t d_rows_num;
    int64_t d_cols_start;
    int64_t d_cols_num;
  };

  virtual PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) = 0;
  virtual void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b,
                        const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out,
                        bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count,
                        cudaStream_t stream) = 0;

  template <typename AType, typename BType, typename DType, typename BiasType>
  void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) {
    cudaStream_t stream{};
    NVTE_CHECK_CUDA(cudaStreamCreate(&stream));

    constexpr float MAX_IN = 1.0;
    std::mt19937 rng(12);
    std::uniform_real_distribution<float> dist(0.0, MAX_IN);
    float a_scale = GetScale<AType>(MAX_IN);
    float b_scale = GetScale<BType>(MAX_IN);
    float d_scale = GetScale<DType>(MAX_IN * MAX_IN * k);
    float bias_scale = GetScale<BiasType>(MAX_IN);

    std::vector<AType> adata(m * k);
    std::generate(adata.begin(), adata.end(),
                  [&rng, &dist, a_scale] { return static_cast<AType>(dist(rng) * a_scale); });
    std::vector<BType> bdata(k * n);
    std::generate(bdata.begin(), bdata.end(),
                  [&rng, &dist, b_scale] { return static_cast<BType>(dist(rng) * b_scale); });
    std::vector<BiasType> biasdata(m * n);
    std::generate(biasdata.begin(), biasdata.end(), [&rng, &dist, bias_scale] {
      return static_cast<BiasType>(dist(rng) * bias_scale);
    });

    auto ga = transa ? MakeFromData<AType>(adata, 0, 0, k, m, k, a_scale)
                     : MakeFromData<AType>(adata, 0, 0, m, k, m, a_scale);
    auto gb = transb ? MakeFromData<BType>(bdata, 0, 0, n, k, n, b_scale)
                     : MakeFromData<BType>(bdata, 0, 0, k, n, k, b_scale);
    auto gbias = MakeFromData<BiasType>(biasdata, 0, 0, m, n, m, bias_scale);
    auto gd = Make<DType>(m, n, d_scale);
    auto gaux = Make<DType>(m, n, d_scale);

    auto dims = DistributeTensors(m, n, k);
    auto a = transa ? MakeFromData<AType>(adata, dims.a_rows_start, dims.a_cols_start,
                                          dims.a_rows_num, dims.a_cols_num, k, a_scale)
                    : MakeFromData<AType>(adata, dims.a_cols_start, dims.a_rows_start,
                                          dims.a_cols_num, dims.a_rows_num, m, a_scale);
    auto b = transb ? MakeFromData<BType>(bdata, dims.b_cols_start, dims.b_rows_start,
                                          dims.b_cols_num, dims.b_rows_num, n, b_scale)
                    : MakeFromData<BType>(bdata, dims.b_rows_start, dims.b_cols_start,
                                          dims.b_rows_num, dims.b_cols_num, k, b_scale);
    auto bias = MakeFromData<BiasType>(biasdata, dims.d_rows_start, dims.d_cols_start,
                                       dims.d_rows_num, dims.d_cols_num, m, bias_scale);
    auto d = Make<DType>(dims.d_rows_num, dims.d_cols_num, d_scale);
    auto aux = Make<DType>(dims.d_rows_num, dims.d_cols_num, d_scale);

    bool grad = false;
    bool accumulate = false;
    CommGemm(m, n, k, a.data(), b.data(), d.data(), bias.data(), aux.data(), transa, transb, grad,
             accumulate, 0 /*comm_sm_count*/, stream);

    auto workspace = Make<uint8_t>(1, 32 << 20, 1.0);
    nvte_cublas_gemm(ga.data(), gb.data(), gd.data(), gbias.data(), gaux.data(), transa, transb,
                     grad, workspace.data(), accumulate, false /* use_split_accumulator */,
                     0 /* math_sm_count */, stream);

    NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
    NVTE_CHECK_CUDA(cudaStreamDestroy(stream));

    std::vector<DType> out(dims.d_rows_num * dims.d_cols_num);
    NVTE_CHECK_CUDA(
        cudaMemcpy(out.data(), d.rowwise_dptr(), out.size() * sizeof out[0], cudaMemcpyDefault));
    std::vector<DType> out_golden_global(m * n);
    NVTE_CHECK_CUDA(cudaMemcpy(out_golden_global.data(), gd.rowwise_dptr(),
                               out_golden_global.size() * sizeof out_golden_global[0],
                               cudaMemcpyDefault));
    auto out_golden = CopyMatrix(out_golden_global, dims.d_rows_start, dims.d_cols_start,
                                 dims.d_rows_num, dims.d_cols_num, m);

    NVTE_CHECK(out.size() == out_golden.size());
    for (size_t i = 0; i < out.size(); ++i) {
      EXPECT_NEAR(static_cast<float>(out[i]), static_cast<float>(out_golden[i]), tol * k);
    }
  }

  NVTECommGemmCtx* ctx_{};
  int nranks_{};
  int rank_{};
  ncclComm_t comm_{};
};

struct AgGemm : public CommGemmFixure {
  PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override {
    auto a_cols_num = nvte_comm_gemm_numroc(ctx_, m);
    auto b_cols_num = nvte_comm_gemm_numroc(ctx_, n);
    int64_t a_cols_start{};
    int64_t b_cols_start{};
    MPI_Exscan(&a_cols_num, &a_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD);
    MPI_Exscan(&b_cols_num, &b_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD);
    return PatternDims{
        .a_rows_start = 0,
        .a_rows_num = k,
        .a_cols_start = a_cols_start,
        .a_cols_num = a_cols_num,
        .b_rows_start = 0,
        .b_rows_num = k,
        .b_cols_start = b_cols_start,
        .b_cols_num = b_cols_num,
        .d_rows_start = a_cols_start,
        .d_rows_num = a_cols_num,
        .d_cols_start = 0,
        .d_cols_num = n,
    };
  }

  void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b,
                const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out,
                bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count,
                cudaStream_t stream) override {
    nvte_all_gather_gemm(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad,
                         accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault);
  }
};

struct GemmRs : public CommGemmFixure {
  PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override {
    auto rows_num = nvte_comm_gemm_numroc(ctx_, k);
    auto d_cols_num = nvte_comm_gemm_numroc(ctx_, n);
    int64_t rows_start{};
    int64_t d_cols_start{};
    MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD);
    MPI_Exscan(&d_cols_num, &d_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD);
    return PatternDims{
        .a_rows_start = rows_start,
        .a_rows_num = rows_num,
        .a_cols_start = 0,
        .a_cols_num = m,
        .b_rows_start = rows_start,
        .b_rows_num = rows_num,
        .b_cols_start = 0,
        .b_cols_num = n,
        .d_rows_start = 0,
        .d_rows_num = m,
        .d_cols_start = d_cols_start,
        .d_cols_num = d_cols_num,
    };
  }

  void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b,
                const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out,
                bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count,
                cudaStream_t stream) override {
    nvte_gemm_reduce_scatter(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad,
                             accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault);
  }
};

struct GemmAr : public CommGemmFixure {
  PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override {
    auto rows_num = nvte_comm_gemm_numroc(ctx_, k);
    int64_t rows_start{};
    MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD);
    return PatternDims{
        .a_rows_start = rows_start,
        .a_rows_num = rows_num,
        .a_cols_start = 0,
        .a_cols_num = m,
        .b_rows_start = rows_start,
        .b_rows_num = rows_num,
        .b_cols_start = 0,
        .b_cols_num = n,
        .d_rows_start = 0,
        .d_rows_num = m,
        .d_cols_start = 0,
        .d_cols_num = n,
    };
  }

  void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b,
                const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out,
                bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count,
                cudaStream_t stream) override {
    nvte_gemm_all_reduce(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad,
                         accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault);
  }

  void SetUp() override {
    if (!IsMulticastSupported(rank_))
      GTEST_SKIP() << "Multicast is not supported on device " << rank_;
  }
};

TEST_P(AgGemm, Gemm) {
  auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam();
  TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
      a_type, AType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          b_type, BType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              d_type, DType, Run<AType, BType, DType, DType>(transa, transb, m, n, k, tol);)));
}

TEST_P(GemmRs, Gemm) {
  auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam();
  TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
      a_type, AType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          b_type, BType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              d_type, DType, Run<AType, BType, DType, DType>(transa, transb, m, n, k, tol);)));
}

TEST_P(GemmAr, Gemm) {
  auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam();
  TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
      a_type, AType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          b_type, BType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              d_type, DType, Run<AType, BType, DType, DType>(transa, transb, m, n, k, tol);)));
}

std::string ParamSuffix(const testing::TestParamInfo<Params>& info) {
  const auto [a_type, b_type, d_type, transa, transb, m, n, k, _tol] = info.param;
  std::ostringstream ss;
  ss << static_cast<int>(a_type) << "_" << static_cast<int>(b_type) << "_"
     << static_cast<int>(d_type) << "_" << (transa ? "T" : "N") << (transb ? "T" : "N") << "_" << m
     << "x" << n << "x" << k;
  return ss.str();
}

INSTANTIATE_TEST_SUITE_P(
    AgGemm, AgGemm,
    testing::Values(
        Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, false, false, 256, 128, 64, 1e-3},
        Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, false, true, 256, 128, 64, 1e-3},
        Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, true, false, 256, 128, 64, 1e-3},
        Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, false, false, 256, 128, 64, 1e-3},
        Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, false, true, 256, 128, 64, 1e-3},
        Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, true, false, 256, 128, 64, 1e-3},
        Params{DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kFloat16, true, false, 256, 128, 64, 1e-3},
        Params{DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kFloat16, true, false, 256, 128, 64, 1e-3},
        Params{DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kFloat16, true, false, 256, 128, 64, 1e-3}),
    &ParamSuffix);

INSTANTIATE_TEST_SUITE_P(
    GemmRs, GemmRs,
    testing::Values(
        Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, false, false, 64, 128, 256, 5e-2},
        Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, false, true, 64, 128, 256, 5e-2},
        Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, true, false, 64, 128, 256, 5e-2},
        Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, false, false, 64, 128, 256, 5e-2},
        Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, false, true, 64, 128, 256, 5e-2},
        Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, true, false, 64, 128, 256, 5e-2},
        Params{DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kFloat16, true, false, 64, 128, 256, 5e-2},
        Params{DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kFloat16, true, false, 64, 128, 256, 5e-2},
        Params{DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kFloat16, true, false, 64, 128, 256, 5e-2}),
    &ParamSuffix);

INSTANTIATE_TEST_SUITE_P(
    GemmAr, GemmAr,
    testing::Values(
        Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, true, false, 64, 64 * 4, 64 * 4, 5e-2},
        Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, true, false, 64, 64 * 4, 64 * 4, 5e-2},
        Params{DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kFloat16, true, false, 128, 128 * 4, 128 * 4, 5e-2},
        Params{DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kFloat16, true, false, 128, 128 * 4, 128 * 4, 5e-2},
        Params{DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kFloat16, true, false, 128, 128 * 4, 128 * 4, 5e-2}),
    &ParamSuffix);
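All three fixtures derive each rank's shard the same way: nvte_comm_gemm_numroc returns the local extent along a partitioned dimension, and an exclusive prefix sum across ranks (MPI_Exscan) turns the per-rank extents into start offsets. Below is a minimal standalone sketch of that partitioning idiom, using only calls that already appear in the test above; the surrounding MPI and context setup is assumed.

#include <mpi.h>
#include <transformer_engine/comm_gemm.h>

#include <cstdint>

struct LocalSpan {
  int64_t start;  // first global index owned by this rank
  int64_t num;    // number of indices owned by this rank
};

LocalSpan PartitionDim(NVTECommGemmCtx* ctx, int64_t global_size) {
  int64_t num = nvte_comm_gemm_numroc(ctx, global_size);
  // MPI_Exscan does not define rank 0's receive buffer, so zero-initialize it,
  // exactly as the fixtures do with `int64_t a_cols_start{};`.
  int64_t start = 0;
  MPI_Exscan(&num, &start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD);
  return {start, num};
}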
tests/pytorch/test_numerics.py

@@ -116,13 +116,18 @@ if fp8_available:
 def is_fused_attn_available(
-    config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
+    config: ModelConfig,
+    dtype: torch.dtype,
+    qkv_layout="bshd_bshd_bshd",
+    is_training=True,
+    deterministic=False,
 ):
     _, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
         is_training=is_training,
+        deterministic=deterministic,
     )
     return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends

@@ -830,7 +835,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
 @pytest.mark.parametrize("model", ["126m"])
 def test_gpt_checkpointing(dtype, bs, model):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype):
+    if not is_fused_attn_available(config, dtype, deterministic=True):
         pytest.skip("No attention backend available.")
     outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
     outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)

@@ -878,7 +883,9 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
 @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
 def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+    if not is_fused_attn_available(
+        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
+    ):
         pytest.skip("No attention backend available.")
     te_gpt = TransformerLayer(

@@ -991,7 +998,9 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
 @pytest.mark.parametrize("mask_type", mask_types)
 def test_mha_accuracy(dtype, bs, model, mask_type):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+    if not is_fused_attn_available(
+        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
+    ):
         pytest.skip("No attention backend available.")
     te_mha = MultiheadAttention(
tests/pytorch/test_onnx_export.py

@@ -36,6 +36,7 @@ import transformer_engine_torch as tex
 from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import get_default_init_method
+import tensorrt as trt

 # Global test configuration knobs.

@@ -113,7 +114,7 @@ def trt_fp8_dequantize(t, scale):
 @onnx_op(
-    op_type="trt::TRT_MXFP8QuantizeLinear",
+    op_type="trt::TRT_MXFP8DynamicQuantize",
     domain="trt",
     inputs=[
         PyCustomOpDef.dt_float,

@@ -1139,3 +1140,59 @@ def test_export_ctx_manager(enabled):
     with te.onnx_export(enabled):
         assert is_in_onnx_export_mode() == enabled
     assert is_in_onnx_export_mode() == False
+
+
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+def test_trt_integration(fp8_recipe: recipe.Recipe):
+    model = te.TransformerLayer(
+        hidden_size=128,
+        ffn_hidden_size=128,
+        num_attention_heads=4,
+    ).eval()
+    inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
+    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+        out_ref = model(*inps)
+
+    onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
+    os.close(onnx_fd)
+    try:
+        with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
+            with te.onnx_export(enabled=True):
+                torch.onnx.export(
+                    model,
+                    inps,
+                    onnx_path,
+                    output_names=["output"],
+                    dynamo=True,
+                    custom_translation_table=te_translation_table,
+                )
+        os.system(f"trtexec --onnx={onnx_path} --saveEngine={onnx_path}.engine")
+
+        # Run TRT engine
+        logger = trt.Logger(trt.Logger.WARNING)
+        runtime = trt.Runtime(logger)
+        with open(onnx_path + ".engine", "rb") as f:
+            engine_data = f.read()
+        engine = runtime.deserialize_cuda_engine(engine_data)
+        context = engine.create_execution_context()
+        context.set_tensor_address(engine.get_tensor_name(0), inps[0].data_ptr())
+        stream = torch.cuda.Stream()
+        out = torch.zeros_like(out_ref)
+        context.set_tensor_address("output", out.data_ptr())
+        context.execute_async_v3(stream_handle=stream.cuda_stream)
+        stream.synchronize()
+
+        # Compare TRT and TE outputs
+        atol = 5e-2 if fp8_recipe is not None else 1e-4
+        rtol = 5e-2 if fp8_recipe is not None else 1e-4
+        torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)
+    finally:
+        try:
+            os.remove(onnx_path)
+        except FileNotFoundError:
+            pass
tests/pytorch/utils.py

@@ -266,8 +266,8 @@ def get_available_attention_backends(
     )
     (
         use_flash_attention,
-        use_fused_attention,
         flash_attention_backend,
+        use_fused_attention,
         fused_attention_backend,
         use_unfused_attention,
         available_backends,
transformer_engine/common/CMakeLists.txt

@@ -169,6 +169,10 @@ if(USE_CUDA)
       comm_gemm_overlap/userbuffers/userbuffers-host.cpp
       comm_gemm_overlap/userbuffers/userbuffers.cu
       comm_gemm_overlap/comm_gemm_overlap.cpp)
+  if (NVTE_WITH_CUBLASMP)
+    list(APPEND transformer_engine_SOURCES
+         comm_gemm/comm_gemm.cpp)
+  endif()
   add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
 else()
   list(APPEND transformer_engine_SOURCES

@@ -272,6 +276,8 @@ if (USE_CUDA)
                         CUDNN::cudnn_all)
   target_include_directories(transformer_engine PRIVATE
                              ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+  target_include_directories(transformer_engine SYSTEM PRIVATE
+                             ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
   target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
 else()
   target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")

@@ -313,11 +319,23 @@ if (NVTE_ENABLE_NVSHMEM)
 option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF)
 if (NVTE_ENABLE_NVSHMEM)
   add_subdirectory(nvshmem_api)
   target_link_libraries(transformer_engine PUBLIC nvshmemapi)
   target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
 endif()

 option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
 if (NVTE_WITH_CUBLASMP)
   target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
   target_include_directories(transformer_engine PRIVATE
                              ${CUBLASMP_DIR}/include
                              ${NVSHMEM_DIR}/include)
   find_library(CUBLASMP_LIB
                NAMES cublasmp libcublasmp
                PATHS ${CUBLASMP_DIR}
                PATH_SUFFIXES lib
                REQUIRED)
   find_library(NVSHMEM_HOST_LIB
                NAMES nvshmem_host libnvshmem_host.so.3
                PATHS ${NVSHMEM_DIR}
                PATH_SUFFIXES lib
                REQUIRED)
   target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB})
   message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
   message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}")
 endif()

 if (USE_CUDA)
transformer_engine/common/comm_gemm/comm_gemm.cpp (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "transformer_engine/comm_gemm.h"

#include <cublasmp.h>
#include <cuda_runtime.h>
#include <nvshmem.h>

#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../common.h"
#include "../util/logging.h"

using namespace transformer_engine;

namespace {

// TODO: log warnings on failures of the *Destroy calls below, once TE has such ability.
// For now, silently ignore the errors: the only diagnostic available in TE is throwing
// exceptions, but these calls will typically be made from destructors, so they cannot throw.

template <typename HandlePtr, typename CreateFn, typename DestroyFn, typename... Args>
auto CreateWithCudaCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) {
  using Handle = std::remove_pointer_t<HandlePtr>;
  HandlePtr raw{};
  NVTE_CHECK_CUDA(create_fn(&raw, std::forward<Args>(args)...));
  return std::unique_ptr<Handle, DestroyFn>(raw, destroy_fn);
}

using CudaStream =
    std::unique_ptr<std::remove_pointer_t<cudaStream_t>, decltype(&cudaStreamDestroy)>;
CudaStream CudaStreamCreate() {
  return CreateWithCudaCheck<cudaStream_t>(cudaStreamCreate, cudaStreamDestroy);
}

using CudaEvent = std::unique_ptr<std::remove_pointer_t<cudaEvent_t>, decltype(&cudaEventDestroy)>;
CudaEvent CudaEventCreate(unsigned flags) {
  return CreateWithCudaCheck<cudaEvent_t>(cudaEventCreateWithFlags, cudaEventDestroy, flags);
}

template <bool raw_last, typename HandlePtr, typename CreateFn, typename DestroyFn,
          typename... Args>
auto CreateWithCublasMpCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) {
  using Handle = std::remove_pointer_t<HandlePtr>;
  HandlePtr raw{};
  if constexpr (raw_last) {
    NVTE_CHECK_CUBLASMP(create_fn(std::forward<Args>(args)..., &raw));
  } else {
    NVTE_CHECK_CUBLASMP(create_fn(&raw, std::forward<Args>(args)...));
  }
  return std::unique_ptr<Handle, DestroyFn>(raw, destroy_fn);
}

using CublasMp =
    std::unique_ptr<std::remove_pointer_t<cublasMpHandle_t>, decltype(&cublasMpDestroy)>;
CublasMp CublasMpCreate(cudaStream_t stream) {
  return CreateWithCublasMpCheck<false, cublasMpHandle_t>(cublasMpCreate, cublasMpDestroy, stream);
}

using CublasMpGrid =
    std::unique_ptr<std::remove_pointer_t<cublasMpGrid_t>, decltype(&cublasMpGridDestroy)>;
CublasMpGrid CublasMpGridCreate(int64_t nprow, int64_t npcol, cublasMpGridLayout_t layout,
                                ncclComm_t comm) {
  return CreateWithCublasMpCheck<true, cublasMpGrid_t>(cublasMpGridCreate, cublasMpGridDestroy,
                                                       nprow, npcol, layout, comm);
}

using CublasMpMatrixDesc = std::unique_ptr<std::remove_pointer_t<cublasMpMatrixDescriptor_t>,
                                           decltype(&cublasMpMatrixDescriptorDestroy)>;
CublasMpMatrixDesc CublasMpMatrixDescCreate(int64_t m, int64_t n, int64_t mb, int64_t nb,
                                            int64_t rsrc, int64_t csrc, int64_t lld,
                                            cudaDataType_t type, cublasMpGrid_t grid) {
  return CreateWithCublasMpCheck<true, cublasMpMatrixDescriptor_t>(
      cublasMpMatrixDescriptorCreate, cublasMpMatrixDescriptorDestroy, m, n, mb, nb, rsrc, csrc,
      lld, type, grid);
}

using CublasMpMatmulDesc = std::unique_ptr<std::remove_pointer_t<cublasMpMatmulDescriptor_t>,
                                           decltype(&cublasMpMatmulDescriptorDestroy)>;
CublasMpMatmulDesc CublasMpMatmulDescCreate(cublasComputeType_t compute_type) {
  return CreateWithCublasMpCheck<false, cublasMpMatmulDescriptor_t>(
      cublasMpMatmulDescriptorCreate, cublasMpMatmulDescriptorDestroy, compute_type);
}

}  // namespace

struct NVTECommGemmCtx {
  int64_t nranks;
  int64_t rank;
  ncclComm_t comm;
  CudaStream stream;
  CudaEvent event;
  CublasMp cublas_mp;
  CublasMpGrid grid_col_major;
  CublasMpGrid grid_row_major;
  CublasMpMatrixDesc a_desc;
  CublasMpMatrixDesc b_desc;
  CublasMpMatrixDesc d_desc;
  CublasMpMatmulDesc matmul_desc;
  void* workspace;
  size_t workspace_size;
};

namespace {

int64_t block_size(NVTECommGemmCtx* ctx, int64_t global_size) {
  // Use non-cyclic layout to maximize opportunity for comm overlap.
  return (global_size + ctx->nranks - 1) / ctx->nranks;
}

void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
                        const Tensor* a, const Tensor* b, const Tensor* d, bool transa,
                        bool transb) {
  const auto a0 = a->flat_first_dim();
  const auto a1 = a->flat_last_dim();
  const auto b0 = b->flat_first_dim();
  const auto b1 = b->flat_last_dim();
  const auto d0 = d->flat_first_dim();
  const auto d1 = d->flat_last_dim();
  if (transa) {
    NVTE_CHECK(a1 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a1);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
        k, m, k, block_size(ctx, m), 0, 0, k, get_cuda_dtype(a->dtype()),
        ctx->grid_row_major.get(), ctx->a_desc.get()));
  } else {
    NVTE_CHECK(a0 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a0);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
        m, k, block_size(ctx, m), k, 0, 0, block_size(ctx, m), get_cuda_dtype(a->dtype()),
        ctx->grid_col_major.get(), ctx->a_desc.get()));
  }
  if (transb) {
    NVTE_CHECK(b0 == k, "Unsupported tensor dimension in B: expected ", k, ", got ", b0);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
        n, k, block_size(ctx, n), k, 0, 0, block_size(ctx, n), get_cuda_dtype(b->dtype()),
        ctx->grid_col_major.get(), ctx->b_desc.get()));
  } else {
    NVTE_CHECK(b1 == k, "Unsupported tensor dimension in B: expected ", k, ", got ", b1);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
        k, n, k, block_size(ctx, n), 0, 0, k, get_cuda_dtype(b->dtype()),
        ctx->grid_row_major.get(), ctx->b_desc.get()));
  }
  NVTE_CHECK(d0 == n, "Unsupported tensor dimension in D: expected ", n, ", got ", d0);
  *ldd = block_size(ctx, m);
  NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
      m, n, block_size(ctx, m), block_size(ctx, n), 0, 0, *ldd, get_cuda_dtype(d->dtype()),
      ctx->grid_col_major.get(), ctx->d_desc.get()));
}

void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
                        const Tensor* a, const Tensor* b, const Tensor* d, bool transa,
                        bool transb) {
  const auto a0 = a->flat_first_dim();
  const auto a1 = a->flat_last_dim();
  const auto b0 = b->flat_first_dim();
  const auto b1 = b->flat_last_dim();
  const auto d0 = d->flat_first_dim();
  const auto d1 = d->flat_last_dim();
  if (transa) {
    NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
        k, m, block_size(ctx, k), m, 0, 0, block_size(ctx, k), get_cuda_dtype(a->dtype()),
        ctx->grid_col_major.get(), ctx->a_desc.get()));
  } else {
    NVTE_CHECK(a1 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a1);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
        m, k, m, block_size(ctx, k), 0, 0, m, get_cuda_dtype(a->dtype()),
        ctx->grid_row_major.get(), ctx->a_desc.get()));
  }
  if (transb) {
    NVTE_CHECK(b1 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b1);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
        n, k, block_size(ctx, n), block_size(ctx, k), 0, 0, block_size(ctx, n),
        get_cuda_dtype(b->dtype()), ctx->grid_row_major.get(), ctx->b_desc.get()));
  } else {
    NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
        k, n, block_size(ctx, k), block_size(ctx, n), 0, 0, block_size(ctx, k),
        get_cuda_dtype(b->dtype()), ctx->grid_col_major.get(), ctx->b_desc.get()));
  }
  NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: expected ", m, ", got ", d1);
  *ldd = m;
  NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
      m, n, m, block_size(ctx, n), 0, 0, *ldd, get_cuda_dtype(d->dtype()),
      ctx->grid_row_major.get(), ctx->d_desc.get()));
}

void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,
                        const Tensor* a, const Tensor* b, const Tensor* d, bool transa,
                        bool transb) {
  const auto a0 = a->flat_first_dim();
  const auto a1 = a->flat_last_dim();
  const auto b0 = b->flat_first_dim();
  const auto b1 = b->flat_last_dim();
  const auto d0 = d->flat_first_dim();
  const auto d1 = d->flat_last_dim();
  if (transa) {
    NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
        k, m, block_size(ctx, k), m, 0, 0, block_size(ctx, k), get_cuda_dtype(a->dtype()),
        ctx->grid_col_major.get(), ctx->a_desc.get()));
  } else {
    NVTE_ERROR("N transpose flag is not supported for input A");
  }
  if (transb) {
    NVTE_ERROR("T transpose flag is not supported for input B");
  } else {
    NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0);
    NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
        k, n, block_size(ctx, k), n, 0, 0, block_size(ctx, k), get_cuda_dtype(b->dtype()),
        ctx->grid_col_major.get(), ctx->b_desc.get()));
  }
  NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: expected ", m, ", got ", d1);
  *ldd = m;
  NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
      m, n * ctx->nranks, m, n, 0, 0, *ldd, get_cuda_dtype(d->dtype()),
      ctx->grid_row_major.get(), ctx->d_desc.get()));
  const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE;
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
      ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
      sizeof epilogue));
}

using InitMatricesFn = void (*)(NVTECommGemmCtx*, int64_t*, int64_t, int64_t, int64_t,
                                const Tensor*, const Tensor*, const Tensor*, bool, bool);

cublasMpMatmulAlgoType_t cublasmp_algo(NVTECommGemmAlgoType algo) {
  static const std::unordered_map<NVTECommGemmAlgoType, cublasMpMatmulAlgoType_t> s_map{
      {kNVTECommGemmAlgoDefault, CUBLASMP_MATMUL_ALGO_TYPE_DEFAULT},
      {kNVTECommGemmAlgoSplitP2P, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_P2P},
      {kNVTECommGemmAlgoSplitMulticast, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_MULTICAST},
      {kNVTECommGemmAlgoAtomicP2P, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_P2P},
      {kNVTECommGemmAlgoAtomicMulticast, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_MULTICAST},
  };
  auto it = s_map.find(algo);
  return it != s_map.end() ? it->second : static_cast<cublasMpMatmulAlgoType_t>(algo);
}

void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx,
                   NVTECommGemmAlgoType algo, int64_t m, int64_t n, int64_t k, const Tensor* a,
                   const Tensor* b, const Tensor* d, const Tensor* bias,
                   const Tensor* pre_act_out, bool transa, bool transb, bool grad,
                   bool accumulate, int comm_sm_count, cudaStream_t main_stream) {
  for (auto t : {a, b, d}) {
    NVTE_CHECK(is_tensor_scaling(t->scaling_mode),
               "Unsupported scaling mode: " + std::to_string(t->scaling_mode));
  }
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorInit(ctx->matmul_desc.get(), CUBLAS_COMPUTE_32F));
  int64_t ldd{};
  init_matrices_fn(ctx, &ldd, m, n, k, a, b, d, transa, transb);
  const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
      ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a,
      sizeof trans_a));
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
      ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b,
      sizeof trans_b));
  cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo);
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
      ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr,
      sizeof algo_attr));
  const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32;
  if (is_fp8_dtype(a->dtype())) {
    NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype");
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode,
        sizeof scale_mode));
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER,
        &a->scale_inv.dptr, sizeof(void*)));
  }
  if (is_fp8_dtype(b->dtype())) {
    NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype");
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode,
        sizeof scale_mode));
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER,
        &b->scale_inv.dptr, sizeof(void*)));
  }
  if (is_fp8_dtype(d->dtype())) {
    NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype");
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode,
        sizeof scale_mode));
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER,
        &d->scale.dptr, sizeof(void*)));
    if (d->amax.dptr) {
      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
          ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER,
          &d->amax.dptr, sizeof(void*)));
    }
  }
  // May have been set to ALLREDUCE earlier; OR the new flags into the current value.
  cublasMpMatmulEpilogue_t epilogue{};
  size_t size_read{};
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet(
      ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
      sizeof epilogue, &size_read));
  NVTE_CHECK(size_read == sizeof epilogue);
  // (bias, gelu, grad) -> epilogue
  const std::map<std::tuple<bool, bool, bool>, cublasMpMatmulEpilogue_t> flags_to_epilogue{
      {{true, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX_BIAS},
      {{true, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU_BGRAD},
      {{true, false, false}, CUBLASMP_MATMUL_EPILOGUE_BIAS},
      {{true, false, true}, CUBLASMP_MATMUL_EPILOGUE_BGRADB},
      {{false, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX},
      {{false, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU},
  };
  if (auto it = flags_to_epilogue.find({bias ? bias->data.dptr != nullptr : false,
                                        pre_act_out ? pre_act_out->data.dptr != nullptr : false,
                                        grad});
      it != flags_to_epilogue.end()) {
    epilogue = static_cast<cublasMpMatmulEpilogue_t>(epilogue | it->second);
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
        sizeof epilogue));
  }
  if (bias && bias->data.dptr) {
    cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype);
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type,
        sizeof bias_type));
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER,
        &bias->data.dptr, sizeof bias->data.dptr));
  }
  if (pre_act_out && pre_act_out->data.dptr) {
    cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype);
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE,
        &aux_type, sizeof aux_type));
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER,
        &pre_act_out->data.dptr, sizeof pre_act_out->data.dptr));
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd,
        sizeof ldd));
    if (is_fp8_dtype(pre_act_out->dtype())) {
      NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype");
      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
          ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE,
          &scale_mode, sizeof scale_mode));
      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
          ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER,
          &pre_act_out->scale.dptr, sizeof(void*)));
      if (pre_act_out->amax.dptr) {
        NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
            ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER,
            &pre_act_out->amax.dptr, sizeof(void*)));
      }
    }
  }
  if (comm_sm_count) {
    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT,
        &comm_sm_count, sizeof comm_sm_count));
  }
  NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream));
  size_t wrksp_size_device{};
  size_t wrksp_size_host{};
  float alpha = 1.0;
  float beta = accumulate ? 1.0 : 0.0;
  std::tuple args{ctx->cublas_mp.get(),
                  ctx->matmul_desc.get(),
                  m,
                  n,
                  k,
                  &alpha,
                  a->data.dptr,
                  1,
                  1,
                  ctx->a_desc.get(),
                  b->data.dptr,
                  1,
                  1,
                  ctx->b_desc.get(),
                  &beta,
                  accumulate ? d->data.dptr : nullptr,
                  1,
                  1,
                  accumulate ? ctx->d_desc.get() : nullptr,
                  d->data.dptr,
                  1,
                  1,
                  ctx->d_desc.get()};
  NVTE_CHECK_CUBLASMP(std::apply(
      cublasMpMatmul_bufferSize,
      std::tuple_cat(args, std::tuple{&wrksp_size_device, &wrksp_size_host})));
  std::vector<uint8_t> workspace_host(wrksp_size_host);
  if (ctx->workspace_size < wrksp_size_device) {
    nvshmem_free(ctx->workspace);
    ctx->workspace = nvshmem_malloc(wrksp_size_device);
    ctx->workspace_size = wrksp_size_device;
  }
  NVTE_CHECK_CUBLASMP(std::apply(
      cublasMpMatmul,
      std::tuple_cat(args, std::tuple{ctx->workspace, ctx->workspace_size, workspace_host.data(),
                                      workspace_host.size()})));
  NVTE_CHECK_CUDA(cudaEventRecord(ctx->event.get(), main_stream));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream.get(), ctx->event.get(), 0));
}

}  // namespace

NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank) {
  NVTE_API_CALL(nvte_comm_gemm_ctx_create);
  auto stream = CudaStreamCreate();
  auto event = CudaEventCreate(cudaEventDisableTiming);
  auto cublas_mp = CublasMpCreate(stream.get());
  auto col_major = CublasMpGridCreate(nranks, 1, CUBLASMP_GRID_LAYOUT_COL_MAJOR, comm);
  auto row_major = CublasMpGridCreate(1, nranks, CUBLASMP_GRID_LAYOUT_ROW_MAJOR, comm);
  // Pre-creating matrix descriptors here; they will be initialized with the actual params later.
  auto a_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get());
  auto b_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get());
  auto d_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get());
  auto matmul_desc = CublasMpMatmulDescCreate(CUBLAS_COMPUTE_32F);
  return new NVTECommGemmCtx{
      .nranks = nranks,
      .rank = rank,
      .comm = comm,
      .stream = std::move(stream),
      .event = std::move(event),
      .cublas_mp = std::move(cublas_mp),
      .grid_col_major = std::move(col_major),
      .grid_row_major = std::move(row_major),
      .a_desc = std::move(a_desc),
      .b_desc = std::move(b_desc),
      .d_desc = std::move(d_desc),
      .matmul_desc = std::move(matmul_desc),
  };
}

void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
  NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
  nvshmemx_sync_all_on_stream(ctx->stream.get());
  delete ctx;
}

void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k,
                          const NVTETensor a, const NVTETensor b, const NVTETensor d,
                          const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
                          bool transb, bool grad, bool accumulate, int comm_sm_count,
                          cudaStream_t main_stream, NVTECommGemmAlgoType algo) {
  NVTE_API_CALL(nvte_all_gather_gemm);
  cublasmp_gemm(AgGemmInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a),
                convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias),
                convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate,
                comm_sm_count, main_stream);
}

void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k,
                              const NVTETensor a, const NVTETensor b, const NVTETensor d,
                              const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
                              bool transb, bool grad, bool accumulate, int comm_sm_count,
                              cudaStream_t main_stream, NVTECommGemmAlgoType algo) {
  NVTE_API_CALL(nvte_gemm_reduce_scatter);
  cublasmp_gemm(GemmRsInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a),
                convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias),
                convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate,
                comm_sm_count, main_stream);
}

void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k,
                          const NVTETensor a, const NVTETensor b, const NVTETensor d,
                          const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
                          bool transb, bool grad, bool accumulate, int comm_sm_count,
                          cudaStream_t main_stream, NVTECommGemmAlgoType algo) {
  NVTE_API_CALL(nvte_gemm_all_reduce);
  cublasmp_gemm(GemmArInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a),
                convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias),
                convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate,
                comm_sm_count, main_stream);
}

int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size) {
  NVTE_API_CALL(nvte_comm_gemm_numroc);
  return cublasMpNumroc(global_size, block_size(ctx, global_size), ctx->rank, 0, ctx->nranks);
}
transformer_engine/common/common.cu

@@ -26,6 +26,24 @@ __global__ void __launch_bounds__(1)
 }  // namespace

+cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
+  using namespace transformer_engine;
+  switch (t) {
+    case DType::kFloat16:
+      return CUDA_R_16F;
+    case DType::kFloat32:
+      return CUDA_R_32F;
+    case DType::kBFloat16:
+      return CUDA_R_16BF;
+    case DType::kFloat8E4M3:
+      return CUDA_R_8F_E4M3;
+    case DType::kFloat8E5M2:
+      return CUDA_R_8F_E5M2;
+    default:
+      NVTE_ERROR("Invalid type");
+  }
+}
+
 void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
   if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) {
     NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv.");
transformer_engine/common/common.h

@@ -276,6 +276,8 @@ struct QuantizationConfig {
 };

+cudaDataType_t get_cuda_dtype(const transformer_engine::DType t);
+
 template <typename T>
 constexpr T DIVUP(const T &x, const T &y) {
   return (((x) + ((y)-1)) / (y));

@@ -395,9 +397,19 @@ struct BitsNumber {
 template <typename T>
 struct TypeInfo {
 #if FP4_TYPE_SUPPORTED
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8, fp4e2m1>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8, fp4e2m1
+#if CUDA_VERSION >= 12080
+                           , fp8e8m0
+#endif
+                           >;
 #else
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8
+#if CUDA_VERSION >= 12080
+                           , fp8e8m0
+#endif
+                           >;
 #endif

 template <typename U, DType current>
transformer_engine/common/gemm/cublaslt_gemm.cu

@@ -29,24 +29,6 @@
 #ifndef __HIP_PLATFORM_AMD__
 namespace {

-cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
-  using namespace transformer_engine;
-  switch (t) {
-    case DType::kFloat16:
-      return CUDA_R_16F;
-    case DType::kFloat32:
-      return CUDA_R_32F;
-    case DType::kBFloat16:
-      return CUDA_R_16BF;
-    case DType::kFloat8E4M3:
-      return CUDA_R_8F_E4M3;
-    case DType::kFloat8E5M2:
-      return CUDA_R_8F_E5M2;
-    default:
-      NVTE_ERROR("Invalid type");
-  }
-}
-
 uint32_t _getAlignment(uintptr_t address) {
   // alignment are in bytes
   uint32_t alignment = 256;
transformer_engine/common/include/transformer_engine/comm_gemm.h
0 → 100644
View file @
e45d66a3
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file comm_gemm.h
*  \brief Functions for distributed (multi-GPU) matrix multiplication.
*
* This API is a TE-native binding to the cuBLASMp library.
* Refer to https://docs.nvidia.com/cuda/cublasmp/usage/tp.html for the specific
* patterns, which allow communication-computation overlap.
*
* All GEMM functions here share the same computational semantics, expressed
* on the global matrices, similar to the nvte_cublas_gemm call:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
* - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
* The functions differ only in their matrix distribution patterns.
*/
#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_
#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_
#include <nccl.h>
#include <stdint.h>
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#else
#include <stdbool.h>
#endif
typedef struct NVTECommGemmCtx NVTECommGemmCtx;
enum NVTECommGemmAlgoType {
  kNVTECommGemmAlgoDefault = 0,
  kNVTECommGemmAlgoSplitP2P = 1,
  kNVTECommGemmAlgoSplitMulticast = 2,
  kNVTECommGemmAlgoAtomicP2P = 3,
  kNVTECommGemmAlgoAtomicMulticast = 4
};
/*! \brief Create a comm-gemm context.
*
* \param[in] comm NCCL communicator.
* \param[in] nranks Number of ranks.
* \param[in] rank Local rank.
*/
NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank);
/*! \brief Destroy a comm-gemm context.
*
* \param[in] ctx Context to destroy.
*/
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx);
/*! \brief Perform AllGather communication followed by GEMM
*
* Gathers distributed data from all ranks, then computes matrix multiplication.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] m Global m dimension.
* \param[in] n Global n dimension.
* \param[in] k Global k dimension.
* \param[in] a Local part of A matrix.
* \param[in] b Local part of B matrix.
* \param[in,out] d Local part of D matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_act_out Local part of output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of gradient computation.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics)
* \param[in] main_stream CUDA stream used for computation.
* \param[in] algo Algorithm to use.
*/
void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k,
                          const NVTETensor a, const NVTETensor b, const NVTETensor d,
                          const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
                          bool transb, bool grad, bool accumulate, int comm_sm_count,
                          cudaStream_t main_stream, NVTECommGemmAlgoType algo);
/*! \brief Perform GEMM followed by ReduceScatter communication
*
* Computes matrix multiplication, then distributes results across ranks with reduction.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] m Global m dimension.
* \param[in] n Global n dimension.
* \param[in] k Global k dimension.
* \param[in] a Local part of A matrix.
* \param[in] b Local part of B matrix.
* \param[in,out] d Local part of D matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_act_out Local part of output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of gradient computation.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics)
* \param[in] main_stream CUDA stream used for computation.
* \param[in] algo Algorithm to use.
*/
void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k,
                              const NVTETensor a, const NVTETensor b, const NVTETensor d,
                              const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
                              bool transb, bool grad, bool accumulate, int comm_sm_count,
                              cudaStream_t main_stream, NVTECommGemmAlgoType algo);
/*! \brief Perform GEMM followed by AllReduce communication
*
* Computes matrix multiplication, then reduces results across all ranks.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] m Global m dimension.
* \param[in] n Global n dimension.
* \param[in] k Global k dimension.
* \param[in] a Local part of A matrix.
* \param[in] b Local part of B matrix.
* \param[in,out] d Local part of D matrix.
* \param[in] bias Bias tensor.
* \param[in,out] pre_act_out Local part of output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of gradient computation.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics)
* \param[in] main_stream CUDA stream used for computation.
* \param[in] algo Algorithm to use.
*/
void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k,
                          const NVTETensor a, const NVTETensor b, const NVTETensor d,
                          const NVTETensor bias, const NVTETensor pre_act_out, bool transa,
                          bool transb, bool grad, bool accumulate, int comm_sm_count,
                          cudaStream_t main_stream, NVTECommGemmAlgoType algo);
/*! \brief Get local number of rows or columns.
*
* Utility function to get local dimension.
* Block size, nranks, and local rank are derived from the context ctx.
*
* \param[in] ctx Comm-GEMM context.
* \param[in] global_size Global dimension.
*/
int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size);
#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_
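To see how the pieces of this header fit together, here is a minimal end-to-end usage sketch. Assumptions: an initialized ncclComm_t, valid NVTETensor shard handles created elsewhere, and a CUDA stream; error handling is omitted, and which dimension is sharded depends on the chosen distribution pattern (see the cuBLASMp TP docs). Illustrative only, not a drop-in snippet.

#include <transformer_engine/comm_gemm.h>

void all_gather_gemm_example(ncclComm_t comm, int nranks, int rank, int64_t m, int64_t n,
                             int64_t k, NVTETensor a, NVTETensor b, NVTETensor d, NVTETensor bias,
                             NVTETensor pre_act_out, cudaStream_t stream) {
  NVTECommGemmCtx* ctx = nvte_comm_gemm_ctx_create(comm, nranks, rank);

  // Shard sizes for distributed dimensions come from numroc; local buffers
  // would be sized with this before the NVTETensors are created.
  int64_t local_k = nvte_comm_gemm_numroc(ctx, k);
  (void)local_k;

  // D = AB (+ bias, + GELU, depending on bias/pre_act_out being non-empty),
  // with the AllGather overlapped against the GEMM.
  nvte_all_gather_gemm(ctx, m, n, k, a, b, d, bias, pre_act_out,
                       /*transa=*/false, /*transb=*/false, /*grad=*/false,
                       /*accumulate=*/false, /*comm_sm_count=*/0, stream,
                       kNVTECommGemmAlgoDefault);

  nvte_comm_gemm_ctx_destroy(ctx);
}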
transformer_engine/common/util/logging.h
View file @
e45d66a3
...
...
@@ -23,8 +23,13 @@
#endif // __HIP_PLATFORM_AMD__
#include <nvrtc.h>
#ifdef NVTE_WITH_CUBLASMP
#include <cublasmp.h>
#endif // NVTE_WITH_CUBLASMP
#include <iostream>
#include <stdexcept>
#include <string>
#include "../util/string.h"
...
...
@@ -130,4 +135,16 @@
} \
} while (false)
#ifdef NVTE_WITH_CUBLASMP
#define NVTE_CHECK_CUBLASMP(expr) \
do { \
const cublasMpStatus_t status = (expr); \
if (status != CUBLASMP_STATUS_SUCCESS) { \
NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \
} \
} while (false)
#endif // NVTE_WITH_CUBLASMP
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
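A usage sketch for the new macro; the cublasMpCreate signature below is our reading of the cuBLASMp docs and should be treated as illustrative:

#ifdef NVTE_WITH_CUBLASMP
// Any cublasMpStatus_t-returning call can be wrapped; a non-success status
// becomes an NVTE_ERROR carrying the numeric status code.
void create_handle_checked(cublasMpHandle_t* handle, cudaStream_t stream) {
  NVTE_CHECK_CUBLASMP(cublasMpCreate(handle, stream));
}
#endif  // NVTE_WITH_CUBLASMP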
transformer_engine/pytorch/attention/dot_product_attention/utils.py
View file @
e45d66a3
...
...
@@ -438,8 +438,8 @@ def get_attention_backend(
# | FP8 | non-paged/paged | sm90 | thd | >= 1
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
    if inference_params is not None:
-        if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0):
-            logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.12")
+        if device_compute_capability == (8, 9) and cudnn_version <= (9, 13, 0):
+            logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.13")
            use_fused_attention = False
        if context_parallel:
            logger.debug("Disabling all backends for KV caching with context parallelism")
...
...
@@ -838,7 +838,7 @@ def get_attention_backend(
# flash-attn >=2.4.1 | yes
# FusedAttention |
# sub-backend 0 | yes
-    # sub-backend 1            | workspace optimization path and sm90+: yes;
+    # sub-backend 1            | workspace optimization path and sm90: yes;
# | otherwise: no
# sub-backend 2 | no
# UnfusedDotProductAttention | yes
...
...
@@ -854,8 +854,9 @@ def get_attention_backend(
            use_flash_attention_2 = False
    if use_fused_attention and deterministic:
        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
-            logger.debug("Disabling FusedAttention for determinism reasons")
+            logger.debug("Disabling FusedAttention for determinism reasons with FP8")
            use_fused_attention = False
            fused_attention_backend = None
        if (
            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and is_training
...
...
@@ -865,8 +866,13 @@ def get_attention_backend(
                or cudnn_version < (8, 9, 5)
            )
        ):
-            logger.debug("Disabling FusedAttention for determinism reasons")
+            logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
            use_fused_attention = False
            fused_attention_backend = None
+        if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0):
+            logger.debug("Disabling FusedAttention for determinism reasons on Blackwell")
+            use_fused_attention = False
+            fused_attention_backend = None

    # use_flash_attention may have been set above
    use_flash_attention_2 = use_flash_attention and use_flash_attention_2
...
...
transformer_engine/pytorch/onnx_extensions.py
View file @
e45d66a3
...
...
@@ -194,12 +194,12 @@ def onnx_quantize_mxfp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType,
) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]:
    """Symbolic quantize to MXFP8Tensor used for inference."""
-    tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor)
+    tensor_out, scale_inv_out = TRT_MXFP8DynamicQuantize(tensor)
    return tensor_out, scale_inv_out
schema = defs.OpSchema(
-    name="TRT_MXFP8QuantizeLinear",
+    name="TRT_MXFP8DynamicQuantize",
    domain="trt",
    since_version=1,
    doc="TRT MXFP8 Quantize Linear used for inference.",
...
...
@@ -214,8 +214,8 @@ schema = defs.OpSchema(
],
)
-TRT_MXFP8QuantizeLinear = onnxscript.values.Op(
-    opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema
+TRT_MXFP8DynamicQuantize = onnxscript.values.Op(
+    opset=trt_opset, name="TRT_MXFP8DynamicQuantize", op_schema=schema
)
...
...