Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
27ddce40
Commit
27ddce40
authored
Oct 11, 2025
by
wenjh
Browse files
Merge branch 'nv_main'
parents
d262ef4c
5b3092a0
Changes
208
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1616 additions
and
155 deletions
+1616
-155
tests/cpp/test_common.cu
tests/cpp/test_common.cu
+10
-1
tests/cpp_distributed/CMakeLists.txt
tests/cpp_distributed/CMakeLists.txt
+57
-0
tests/cpp_distributed/test_comm_gemm.cu
tests/cpp_distributed/test_comm_gemm.cu
+441
-0
tests/jax/distributed_test_base.py
tests/jax/distributed_test_base.py
+5
-5
tests/jax/multi_process_launch.sh
tests/jax/multi_process_launch.sh
+23
-0
tests/jax/test_custom_call_compute.py
tests/jax/test_custom_call_compute.py
+20
-9
tests/jax/test_distributed_fused_attn.py
tests/jax/test_distributed_fused_attn.py
+2
-2
tests/jax/test_distributed_helper.py
tests/jax/test_distributed_helper.py
+35
-0
tests/jax/test_distributed_layernorm_mlp.py
tests/jax/test_distributed_layernorm_mlp.py
+26
-19
tests/jax/test_distributed_softmax.py
tests/jax/test_distributed_softmax.py
+2
-2
tests/jax/test_fused_attn.py
tests/jax/test_fused_attn.py
+12
-3
tests/jax/test_helper.py
tests/jax/test_helper.py
+34
-50
tests/jax/test_layer.py
tests/jax/test_layer.py
+35
-15
tests/jax/test_multi_process_distributed_grouped_gemm.py
tests/jax/test_multi_process_distributed_grouped_gemm.py
+172
-0
tests/jax/test_sharding.py
tests/jax/test_sharding.py
+0
-38
tests/pytorch/attention/test_attention.py
tests/pytorch/attention/test_attention.py
+2
-0
tests/pytorch/attention/test_attention_with_cp.py
tests/pytorch/attention/test_attention_with_cp.py
+8
-0
tests/pytorch/attention/test_cp_utils.py
tests/pytorch/attention/test_cp_utils.py
+715
-0
tests/pytorch/debug/test_api_features.py
tests/pytorch/debug/test_api_features.py
+7
-5
tests/pytorch/debug/test_log.py
tests/pytorch/debug/test_log.py
+10
-6
No files found.
tests/cpp/test_common.cu
View file @
27ddce40
...
...
@@ -891,9 +891,18 @@ void fillCase(Tensor *t, const InputsFillCase fill_case) {
}
}
template
void
fillCase
<
byte
>(
Tensor
*
t
,
const
InputsFillCase
fill_case
);
template
void
fillCase
<
int16
>(
Tensor
*
t
,
const
InputsFillCase
fill_case
);
template
void
fillCase
<
int32
>(
Tensor
*
t
,
const
InputsFillCase
fill_case
);
template
void
fillCase
<
int64
>(
Tensor
*
t
,
const
InputsFillCase
fill_case
);
template
void
fillCase
<
fp32
>(
Tensor
*
t
,
const
InputsFillCase
fill_case
);
template
void
fillCase
<
fp16
>(
Tensor
*
t
,
const
InputsFillCase
fill_case
);
template
void
fillCase
<
bf16
>(
Tensor
*
t
,
const
InputsFillCase
fill_case
);
template
void
fillCase
<
fp8e4m3
>(
Tensor
*
t
,
const
InputsFillCase
fill_case
);
template
void
fillCase
<
fp8e5m2
>(
Tensor
*
t
,
const
InputsFillCase
fill_case
);
template
void
fillCase
<
fp32
>(
Tensor
*
t
,
const
InputsFillCase
fill_case
);
#if FP4_TYPE_SUPPORTED
template
void
fillCase
<
fp4e2m1
>(
Tensor
*
t
,
const
InputsFillCase
fill_case
);
#endif
void
setRandomScale
(
Tensor
*
t
)
{
std
::
uniform_real_distribution
<>
dis
(
-
2.0
,
1.0
);
...
...
tests/cpp_distributed/CMakeLists.txt
0 → 100644
View file @
27ddce40
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
cmake_minimum_required
(
VERSION 3.18
)
if
(
NOT DEFINED CMAKE_CUDA_ARCHITECTURES
)
if
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8
)
set
(
CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120
)
else
()
set
(
CMAKE_CUDA_ARCHITECTURES 75 80 89 90
)
endif
()
endif
()
set
(
CMAKE_CXX_STANDARD 17
)
set
(
CMAKE_CUDA_STANDARD 17
)
set
(
CMAKE_CUDA_STANDARD_REQUIRED ON
)
project
(
transformer_engine_distributed_tests LANGUAGES CUDA CXX
)
add_subdirectory
(
../../3rdparty/googletest
${
PROJECT_BINARY_DIR
}
/googletest
)
include_directories
(
${
gtest_SOURCE_DIR
}
/include
${
gtest_SOURCE_DIR
}
)
if
(
NOT DEFINED TE_LIB_PATH
)
execute_process
(
COMMAND bash -c
"python3 -c 'import transformer_engine as te; print(te.__file__)'"
OUTPUT_VARIABLE TE_LIB_FILE
OUTPUT_STRIP_TRAILING_WHITESPACE
)
get_filename_component
(
TE_LIB_PATH
${
TE_LIB_FILE
}
DIRECTORY
)
endif
()
find_library
(
TE_LIB NAMES transformer_engine PATHS
"
${
TE_LIB_PATH
}
/.."
${
TE_LIB_PATH
}
ENV TE_LIB_PATH REQUIRED
)
message
(
STATUS
"Found transformer_engine library:
${
TE_LIB
}
"
)
include_directories
(
../../transformer_engine/common/include
)
include_directories
(
../../transformer_engine/common
)
include_directories
(
../../transformer_engine
)
include_directories
(
${
CMAKE_SOURCE_DIR
}
)
find_package
(
CUDAToolkit REQUIRED
)
add_executable
(
test_comm_gemm
test_comm_gemm.cu
../cpp/test_common.cu
)
find_package
(
OpenMP REQUIRED
)
find_package
(
MPI REQUIRED
)
find_library
(
NCCL_LIB
NAMES nccl libnccl
PATH_SUFFIXES lib
REQUIRED
)
target_include_directories
(
test_comm_gemm PRIVATE
${
MPI_CXX_INCLUDE_PATH
}
$ENV{CUBLASMP_HOME}/include
)
target_link_libraries
(
test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest
${
TE_LIB
}
CUDA::nvrtc MPI::MPI_CXX
${
NCCL_LIB
}
OpenMP::OpenMP_CXX
)
include
(
GoogleTest
)
gtest_discover_tests
(
test_comm_gemm DISCOVERY_TIMEOUT 600
)
tests/cpp_distributed/test_comm_gemm.cu
0 → 100644
View file @
27ddce40
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <gtest/gtest.h>
#include <mpi.h>
#include <nccl.h>
#include <transformer_engine/comm_gemm.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <iostream>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <vector>
#include "../cpp/test_common.h"
#include "common.h"
using
transformer_engine
::
DType
;
using
transformer_engine
::
TypeInfo
;
#define CHECK_MPI(expr) \
do { \
int err = (expr); \
if (err != MPI_SUCCESS) { \
char err_str[MPI_MAX_ERROR_STRING + 1]{}; \
int _len{}; \
MPI_Error_string(err, err_str, &_len); \
EXPECT_TRUE(false) << "MPI error: " << err << ": " << err_str; \
} \
} while (false)
#define CHECK_NCCL(expr) \
do { \
ncclResult_t err = (expr); \
if (err != ncclSuccess) { \
EXPECT_TRUE(false) << "NCCL error: " << err << ": " << ncclGetErrorString(err); \
} \
} while (false)
#define CHECK_CU(expr) \
do { \
CUresult err = (expr); \
if (err != CUDA_SUCCESS) { \
const char* str{}; \
CUresult e_str = cuGetErrorString(err, &str); \
if (e_str != CUDA_SUCCESS) str = "(unknown)"; \
EXPECT_TRUE(false) << "CU error: " << err << ": " << str; \
} \
} while (false)
int
main
(
int
argc
,
char
*
argv
[])
{
::
testing
::
InitGoogleTest
(
&
argc
,
argv
);
CHECK_MPI
(
MPI_Init
(
&
argc
,
&
argv
));
auto
ret
=
RUN_ALL_TESTS
();
CHECK_MPI
(
MPI_Finalize
());
return
ret
;
}
bool
IsMulticastSupported
(
int
device_id
)
{
int
supported
=
0
;
CHECK_CU
(
cuDeviceGetAttribute
(
&
supported
,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED
,
device_id
));
return
supported
;
}
template
<
typename
T
>
std
::
vector
<
T
>
CopyMatrix
(
const
std
::
vector
<
T
>&
data
,
size_t
mstart
,
size_t
nstart
,
size_t
msize
,
size_t
nsize
,
size_t
ld
)
{
std
::
vector
<
T
>
ret
(
msize
*
nsize
);
size_t
dst
=
0
;
for
(
size_t
j
=
nstart
;
j
<
nstart
+
nsize
;
++
j
)
{
for
(
size_t
i
=
mstart
;
i
<
mstart
+
msize
;
++
i
)
{
ret
[
dst
++
]
=
data
[
j
*
ld
+
i
];
}
}
return
ret
;
}
template
<
typename
T
>
test
::
Tensor
Make
(
size_t
m
,
size_t
n
,
float
scale
)
{
test
::
Tensor
ret
(
""
,
std
::
vector
{
n
,
m
},
TypeInfo
<
T
>::
dtype
);
ret
.
set_scale
(
scale
);
ret
.
set_scale_inv
(
1.0
/
scale
);
return
ret
;
}
template
<
typename
T
>
test
::
Tensor
MakeFromData
(
const
std
::
vector
<
T
>&
data
,
size_t
mstart
,
size_t
nstart
,
size_t
msize
,
size_t
nsize
,
size_t
ld
,
float
scale
)
{
test
::
Tensor
ret
(
""
,
std
::
vector
{
nsize
,
msize
},
TypeInfo
<
T
>::
dtype
);
ret
.
set_scale
(
scale
);
ret
.
set_scale_inv
(
1.0
/
scale
);
auto
local
=
CopyMatrix
(
data
,
mstart
,
nstart
,
msize
,
nsize
,
ld
);
NVTE_CHECK_CUDA
(
cudaMemcpy
(
ret
.
rowwise_dptr
(),
local
.
data
(),
local
.
size
()
*
sizeof
local
[
0
],
cudaMemcpyDefault
));
return
ret
;
}
template
<
typename
T
>
float
GetScale
(
float
amax
)
{
if
constexpr
(
sizeof
(
T
)
>
1
)
return
1.0
;
return
static_cast
<
float
>
(
static_cast
<
T
>
(
std
::
numeric_limits
<
float
>::
max
()))
/
amax
;
}
struct
Params
{
DType
a_type
;
DType
b_type
;
DType
d_type
;
bool
transa
;
bool
transb
;
size_t
m
;
size_t
n
;
size_t
k
;
float
tol
;
};
class
CommGemmFixure
:
public
::
testing
::
TestWithParam
<
Params
>
{
protected:
CommGemmFixure
()
{
CHECK_MPI
(
MPI_Comm_size
(
MPI_COMM_WORLD
,
&
nranks_
));
CHECK_MPI
(
MPI_Comm_rank
(
MPI_COMM_WORLD
,
&
rank_
));
NVTE_CHECK_CUDA
(
cudaSetDevice
(
rank_
));
ncclUniqueId
id
{};
if
(
rank_
==
0
)
CHECK_NCCL
(
ncclGetUniqueId
(
&
id
));
CHECK_MPI
(
MPI_Bcast
(
&
id
,
sizeof
(
id
),
MPI_BYTE
,
0
,
MPI_COMM_WORLD
));
CHECK_NCCL
(
ncclCommInitRank
(
&
comm_
,
nranks_
,
id
,
rank_
));
ctx_
=
nvte_comm_gemm_ctx_create
(
comm_
,
nranks_
,
rank_
);
}
~
CommGemmFixure
()
{
nvte_comm_gemm_ctx_destroy
(
ctx_
);
ncclCommDestroy
(
comm_
);
}
struct
PatternDims
{
int64_t
a_rows_start
;
int64_t
a_rows_num
;
int64_t
a_cols_start
;
int64_t
a_cols_num
;
int64_t
b_rows_start
;
int64_t
b_rows_num
;
int64_t
b_cols_start
;
int64_t
b_cols_num
;
int64_t
d_rows_start
;
int64_t
d_rows_num
;
int64_t
d_cols_start
;
int64_t
d_cols_num
;
};
virtual
PatternDims
DistributeTensors
(
int64_t
m
,
int64_t
n
,
int64_t
k
)
=
0
;
virtual
void
CommGemm
(
int64_t
m
,
int64_t
n
,
int64_t
k
,
const
NVTETensor
a
,
const
NVTETensor
b
,
const
NVTETensor
d
,
const
NVTETensor
bias
,
const
NVTETensor
pre_act_out
,
bool
transa
,
bool
transb
,
bool
grad
,
bool
accumulate
,
int
comm_sm_count
,
cudaStream_t
stream
)
=
0
;
template
<
typename
AType
,
typename
BType
,
typename
DType
,
typename
BiasType
>
void
Run
(
bool
transa
,
bool
transb
,
size_t
m
,
size_t
n
,
size_t
k
,
float
tol
)
{
cudaStream_t
stream
{};
NVTE_CHECK_CUDA
(
cudaStreamCreate
(
&
stream
));
constexpr
float
MAX_IN
=
1.0
;
std
::
mt19937
rng
(
12
);
std
::
uniform_real_distribution
<
float
>
dist
(
0.0
,
MAX_IN
);
float
a_scale
=
GetScale
<
AType
>
(
MAX_IN
);
float
b_scale
=
GetScale
<
BType
>
(
MAX_IN
);
float
d_scale
=
GetScale
<
DType
>
(
MAX_IN
*
MAX_IN
*
k
);
float
bias_scale
=
GetScale
<
BiasType
>
(
MAX_IN
);
std
::
vector
<
AType
>
adata
(
m
*
k
);
std
::
generate
(
adata
.
begin
(),
adata
.
end
(),
[
&
rng
,
&
dist
,
a_scale
]
{
return
static_cast
<
AType
>
(
dist
(
rng
)
*
a_scale
);
});
std
::
vector
<
BType
>
bdata
(
k
*
n
);
std
::
generate
(
bdata
.
begin
(),
bdata
.
end
(),
[
&
rng
,
&
dist
,
b_scale
]
{
return
static_cast
<
BType
>
(
dist
(
rng
)
*
b_scale
);
});
std
::
vector
<
BiasType
>
biasdata
(
m
*
n
);
std
::
generate
(
biasdata
.
begin
(),
biasdata
.
end
(),
[
&
rng
,
&
dist
,
bias_scale
]
{
return
static_cast
<
BiasType
>
(
dist
(
rng
)
*
bias_scale
);
});
auto
ga
=
transa
?
MakeFromData
<
AType
>
(
adata
,
0
,
0
,
k
,
m
,
k
,
a_scale
)
:
MakeFromData
<
AType
>
(
adata
,
0
,
0
,
m
,
k
,
m
,
a_scale
);
auto
gb
=
transb
?
MakeFromData
<
BType
>
(
bdata
,
0
,
0
,
n
,
k
,
n
,
b_scale
)
:
MakeFromData
<
BType
>
(
bdata
,
0
,
0
,
k
,
n
,
k
,
b_scale
);
auto
gbias
=
MakeFromData
<
BiasType
>
(
biasdata
,
0
,
0
,
m
,
n
,
m
,
bias_scale
);
auto
gd
=
Make
<
DType
>
(
m
,
n
,
d_scale
);
auto
gaux
=
Make
<
DType
>
(
m
,
n
,
d_scale
);
auto
dims
=
DistributeTensors
(
m
,
n
,
k
);
auto
a
=
transa
?
MakeFromData
<
AType
>
(
adata
,
dims
.
a_rows_start
,
dims
.
a_cols_start
,
dims
.
a_rows_num
,
dims
.
a_cols_num
,
k
,
a_scale
)
:
MakeFromData
<
AType
>
(
adata
,
dims
.
a_cols_start
,
dims
.
a_rows_start
,
dims
.
a_cols_num
,
dims
.
a_rows_num
,
m
,
a_scale
);
auto
b
=
transb
?
MakeFromData
<
BType
>
(
bdata
,
dims
.
b_cols_start
,
dims
.
b_rows_start
,
dims
.
b_cols_num
,
dims
.
b_rows_num
,
n
,
b_scale
)
:
MakeFromData
<
BType
>
(
bdata
,
dims
.
b_rows_start
,
dims
.
b_cols_start
,
dims
.
b_rows_num
,
dims
.
b_cols_num
,
k
,
b_scale
);
auto
bias
=
MakeFromData
<
BiasType
>
(
biasdata
,
dims
.
d_rows_start
,
dims
.
d_cols_start
,
dims
.
d_rows_num
,
dims
.
d_cols_num
,
m
,
bias_scale
);
auto
d
=
Make
<
DType
>
(
dims
.
d_rows_num
,
dims
.
d_cols_num
,
d_scale
);
auto
aux
=
Make
<
DType
>
(
dims
.
d_rows_num
,
dims
.
d_cols_num
,
d_scale
);
bool
grad
=
false
;
bool
accumulate
=
false
;
CommGemm
(
m
,
n
,
k
,
a
.
data
(),
b
.
data
(),
d
.
data
(),
bias
.
data
(),
aux
.
data
(),
transa
,
transb
,
grad
,
accumulate
,
0
/*comm_sm_count*/
,
stream
);
auto
workspace
=
Make
<
uint8_t
>
(
1
,
32
<<
20
,
1.0
);
nvte_cublas_gemm
(
ga
.
data
(),
gb
.
data
(),
gd
.
data
(),
gbias
.
data
(),
gaux
.
data
(),
transa
,
transb
,
grad
,
workspace
.
data
(),
accumulate
,
false
/* use_split_accumulator */
,
0
/* math_sm_count */
,
stream
);
NVTE_CHECK_CUDA
(
cudaStreamSynchronize
(
stream
));
NVTE_CHECK_CUDA
(
cudaStreamDestroy
(
stream
));
std
::
vector
<
DType
>
out
(
dims
.
d_rows_num
*
dims
.
d_cols_num
);
NVTE_CHECK_CUDA
(
cudaMemcpy
(
out
.
data
(),
d
.
rowwise_dptr
(),
out
.
size
()
*
sizeof
out
[
0
],
cudaMemcpyDefault
));
std
::
vector
<
DType
>
out_golden_global
(
m
*
n
);
NVTE_CHECK_CUDA
(
cudaMemcpy
(
out_golden_global
.
data
(),
gd
.
rowwise_dptr
(),
out_golden_global
.
size
()
*
sizeof
out_golden_global
[
0
],
cudaMemcpyDefault
));
auto
out_golden
=
CopyMatrix
(
out_golden_global
,
dims
.
d_rows_start
,
dims
.
d_cols_start
,
dims
.
d_rows_num
,
dims
.
d_cols_num
,
m
);
NVTE_CHECK
(
out
.
size
()
==
out_golden
.
size
());
for
(
size_t
i
=
0
;
i
<
out
.
size
();
++
i
)
{
EXPECT_NEAR
(
static_cast
<
float
>
(
out
[
i
]),
static_cast
<
float
>
(
out_golden
[
i
]),
tol
*
k
);
}
}
NVTECommGemmCtx
*
ctx_
{};
int
nranks_
{};
int
rank_
{};
ncclComm_t
comm_
{};
};
struct
AgGemm
:
public
CommGemmFixure
{
PatternDims
DistributeTensors
(
int64_t
m
,
int64_t
n
,
int64_t
k
)
override
{
auto
a_cols_num
=
nvte_comm_gemm_numroc
(
ctx_
,
m
);
auto
b_cols_num
=
nvte_comm_gemm_numroc
(
ctx_
,
n
);
int64_t
a_cols_start
{};
int64_t
b_cols_start
{};
MPI_Exscan
(
&
a_cols_num
,
&
a_cols_start
,
1
,
MPI_INT64_T
,
MPI_SUM
,
MPI_COMM_WORLD
);
MPI_Exscan
(
&
b_cols_num
,
&
b_cols_start
,
1
,
MPI_INT64_T
,
MPI_SUM
,
MPI_COMM_WORLD
);
return
PatternDims
{
.
a_rows_start
=
0
,
.
a_rows_num
=
k
,
.
a_cols_start
=
a_cols_start
,
.
a_cols_num
=
a_cols_num
,
.
b_rows_start
=
0
,
.
b_rows_num
=
k
,
.
b_cols_start
=
b_cols_start
,
.
b_cols_num
=
b_cols_num
,
.
d_rows_start
=
a_cols_start
,
.
d_rows_num
=
a_cols_num
,
.
d_cols_start
=
0
,
.
d_cols_num
=
n
,
};
}
void
CommGemm
(
int64_t
m
,
int64_t
n
,
int64_t
k
,
const
NVTETensor
a
,
const
NVTETensor
b
,
const
NVTETensor
d
,
const
NVTETensor
bias
,
const
NVTETensor
pre_act_out
,
bool
transa
,
bool
transb
,
bool
grad
,
bool
accumulate
,
int
comm_sm_count
,
cudaStream_t
stream
)
override
{
nvte_all_gather_gemm
(
ctx_
,
m
,
n
,
k
,
a
,
b
,
d
,
bias
,
pre_act_out
,
transa
,
transb
,
grad
,
accumulate
,
comm_sm_count
,
stream
,
kNVTECommGemmAlgoDefault
);
}
};
struct
GemmRs
:
public
CommGemmFixure
{
PatternDims
DistributeTensors
(
int64_t
m
,
int64_t
n
,
int64_t
k
)
override
{
auto
rows_num
=
nvte_comm_gemm_numroc
(
ctx_
,
k
);
auto
d_cols_num
=
nvte_comm_gemm_numroc
(
ctx_
,
n
);
int64_t
rows_start
{};
int64_t
d_cols_start
{};
MPI_Exscan
(
&
rows_num
,
&
rows_start
,
1
,
MPI_INT64_T
,
MPI_SUM
,
MPI_COMM_WORLD
);
MPI_Exscan
(
&
d_cols_num
,
&
d_cols_start
,
1
,
MPI_INT64_T
,
MPI_SUM
,
MPI_COMM_WORLD
);
return
PatternDims
{
.
a_rows_start
=
rows_start
,
.
a_rows_num
=
rows_num
,
.
a_cols_start
=
0
,
.
a_cols_num
=
m
,
.
b_rows_start
=
rows_start
,
.
b_rows_num
=
rows_num
,
.
b_cols_start
=
0
,
.
b_cols_num
=
n
,
.
d_rows_start
=
0
,
.
d_rows_num
=
m
,
.
d_cols_start
=
d_cols_start
,
.
d_cols_num
=
d_cols_num
,
};
}
void
CommGemm
(
int64_t
m
,
int64_t
n
,
int64_t
k
,
const
NVTETensor
a
,
const
NVTETensor
b
,
const
NVTETensor
d
,
const
NVTETensor
bias
,
const
NVTETensor
pre_act_out
,
bool
transa
,
bool
transb
,
bool
grad
,
bool
accumulate
,
int
comm_sm_count
,
cudaStream_t
stream
)
override
{
nvte_gemm_reduce_scatter
(
ctx_
,
m
,
n
,
k
,
a
,
b
,
d
,
bias
,
pre_act_out
,
transa
,
transb
,
grad
,
accumulate
,
comm_sm_count
,
stream
,
kNVTECommGemmAlgoDefault
);
}
};
struct
GemmAr
:
public
CommGemmFixure
{
PatternDims
DistributeTensors
(
int64_t
m
,
int64_t
n
,
int64_t
k
)
override
{
auto
rows_num
=
nvte_comm_gemm_numroc
(
ctx_
,
k
);
int64_t
rows_start
{};
MPI_Exscan
(
&
rows_num
,
&
rows_start
,
1
,
MPI_INT64_T
,
MPI_SUM
,
MPI_COMM_WORLD
);
return
PatternDims
{
.
a_rows_start
=
rows_start
,
.
a_rows_num
=
rows_num
,
.
a_cols_start
=
0
,
.
a_cols_num
=
m
,
.
b_rows_start
=
rows_start
,
.
b_rows_num
=
rows_num
,
.
b_cols_start
=
0
,
.
b_cols_num
=
n
,
.
d_rows_start
=
0
,
.
d_rows_num
=
m
,
.
d_cols_start
=
0
,
.
d_cols_num
=
n
,
};
}
void
CommGemm
(
int64_t
m
,
int64_t
n
,
int64_t
k
,
const
NVTETensor
a
,
const
NVTETensor
b
,
const
NVTETensor
d
,
const
NVTETensor
bias
,
const
NVTETensor
pre_act_out
,
bool
transa
,
bool
transb
,
bool
grad
,
bool
accumulate
,
int
comm_sm_count
,
cudaStream_t
stream
)
override
{
nvte_gemm_all_reduce
(
ctx_
,
m
,
n
,
k
,
a
,
b
,
d
,
bias
,
pre_act_out
,
transa
,
transb
,
grad
,
accumulate
,
comm_sm_count
,
stream
,
kNVTECommGemmAlgoDefault
);
}
void
SetUp
()
override
{
if
(
!
IsMulticastSupported
(
rank_
))
GTEST_SKIP
()
<<
"Multicast is not supported on device "
<<
rank_
;
}
};
TEST_P
(
AgGemm
,
Gemm
)
{
auto
[
a_type
,
b_type
,
d_type
,
transa
,
transb
,
m
,
n
,
k
,
tol
]
=
GetParam
();
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
a_type
,
AType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
b_type
,
BType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
d_type
,
DType
,
Run
<
AType
,
BType
,
DType
,
DType
>
(
transa
,
transb
,
m
,
n
,
k
,
tol
);)));
}
TEST_P
(
GemmRs
,
Gemm
)
{
auto
[
a_type
,
b_type
,
d_type
,
transa
,
transb
,
m
,
n
,
k
,
tol
]
=
GetParam
();
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
a_type
,
AType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
b_type
,
BType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
d_type
,
DType
,
Run
<
AType
,
BType
,
DType
,
DType
>
(
transa
,
transb
,
m
,
n
,
k
,
tol
);)));
}
TEST_P
(
GemmAr
,
Gemm
)
{
auto
[
a_type
,
b_type
,
d_type
,
transa
,
transb
,
m
,
n
,
k
,
tol
]
=
GetParam
();
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
a_type
,
AType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
b_type
,
BType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
d_type
,
DType
,
Run
<
AType
,
BType
,
DType
,
DType
>
(
transa
,
transb
,
m
,
n
,
k
,
tol
);)));
}
std
::
string
ParamSuffix
(
const
testing
::
TestParamInfo
<
Params
>&
info
)
{
const
auto
[
a_type
,
b_type
,
d_type
,
transa
,
transb
,
m
,
n
,
k
,
_tol
]
=
info
.
param
;
std
::
ostringstream
ss
;
ss
<<
static_cast
<
int
>
(
a_type
)
<<
"_"
<<
static_cast
<
int
>
(
b_type
)
<<
"_"
<<
static_cast
<
int
>
(
d_type
)
<<
"_"
<<
(
transa
?
"T"
:
"N"
)
<<
(
transb
?
"T"
:
"N"
)
<<
"_"
<<
m
<<
"x"
<<
n
<<
"x"
<<
k
;
return
ss
.
str
();
}
INSTANTIATE_TEST_SUITE_P
(
AgGemm
,
AgGemm
,
testing
::
Values
(
Params
{
DType
::
kFloat16
,
DType
::
kFloat16
,
DType
::
kFloat16
,
false
,
false
,
256
,
128
,
64
,
1e-3
},
Params
{
DType
::
kFloat16
,
DType
::
kFloat16
,
DType
::
kFloat16
,
false
,
true
,
256
,
128
,
64
,
1e-3
},
Params
{
DType
::
kFloat16
,
DType
::
kFloat16
,
DType
::
kFloat16
,
true
,
false
,
256
,
128
,
64
,
1e-3
},
Params
{
DType
::
kBFloat16
,
DType
::
kBFloat16
,
DType
::
kBFloat16
,
false
,
false
,
256
,
128
,
64
,
1e-3
},
Params
{
DType
::
kBFloat16
,
DType
::
kBFloat16
,
DType
::
kBFloat16
,
false
,
true
,
256
,
128
,
64
,
1e-3
},
Params
{
DType
::
kBFloat16
,
DType
::
kBFloat16
,
DType
::
kBFloat16
,
true
,
false
,
256
,
128
,
64
,
1e-3
},
Params
{
DType
::
kFloat8E4M3
,
DType
::
kFloat8E4M3
,
DType
::
kFloat16
,
true
,
false
,
256
,
128
,
64
,
1e-3
},
Params
{
DType
::
kFloat8E4M3
,
DType
::
kFloat8E5M2
,
DType
::
kFloat16
,
true
,
false
,
256
,
128
,
64
,
1e-3
},
Params
{
DType
::
kFloat8E5M2
,
DType
::
kFloat8E4M3
,
DType
::
kFloat16
,
true
,
false
,
256
,
128
,
64
,
1e-3
}),
&
ParamSuffix
);
INSTANTIATE_TEST_SUITE_P
(
GemmRs
,
GemmRs
,
testing
::
Values
(
Params
{
DType
::
kFloat16
,
DType
::
kFloat16
,
DType
::
kFloat16
,
false
,
false
,
64
,
128
,
256
,
5e-2
},
Params
{
DType
::
kFloat16
,
DType
::
kFloat16
,
DType
::
kFloat16
,
false
,
true
,
64
,
128
,
256
,
5e-2
},
Params
{
DType
::
kFloat16
,
DType
::
kFloat16
,
DType
::
kFloat16
,
true
,
false
,
64
,
128
,
256
,
5e-2
},
Params
{
DType
::
kBFloat16
,
DType
::
kBFloat16
,
DType
::
kBFloat16
,
false
,
false
,
64
,
128
,
256
,
5e-2
},
Params
{
DType
::
kBFloat16
,
DType
::
kBFloat16
,
DType
::
kBFloat16
,
false
,
true
,
64
,
128
,
256
,
5e-2
},
Params
{
DType
::
kBFloat16
,
DType
::
kBFloat16
,
DType
::
kBFloat16
,
true
,
false
,
64
,
128
,
256
,
5e-2
},
Params
{
DType
::
kFloat8E4M3
,
DType
::
kFloat8E4M3
,
DType
::
kFloat16
,
true
,
false
,
64
,
128
,
256
,
5e-2
},
Params
{
DType
::
kFloat8E4M3
,
DType
::
kFloat8E5M2
,
DType
::
kFloat16
,
true
,
false
,
64
,
128
,
256
,
5e-2
},
Params
{
DType
::
kFloat8E5M2
,
DType
::
kFloat8E4M3
,
DType
::
kFloat16
,
true
,
false
,
64
,
128
,
256
,
5e-2
}),
&
ParamSuffix
);
INSTANTIATE_TEST_SUITE_P
(
GemmAr
,
GemmAr
,
testing
::
Values
(
Params
{
DType
::
kFloat16
,
DType
::
kFloat16
,
DType
::
kFloat16
,
true
,
false
,
64
,
64
*
4
,
64
*
4
,
5e-2
},
Params
{
DType
::
kBFloat16
,
DType
::
kBFloat16
,
DType
::
kBFloat16
,
true
,
false
,
64
,
64
*
4
,
64
*
4
,
5e-2
},
Params
{
DType
::
kFloat8E5M2
,
DType
::
kFloat8E4M3
,
DType
::
kFloat16
,
true
,
false
,
128
,
128
*
4
,
128
*
4
,
5e-2
},
Params
{
DType
::
kFloat8E4M3
,
DType
::
kFloat8E5M2
,
DType
::
kFloat16
,
true
,
false
,
128
,
128
*
4
,
128
*
4
,
5e-2
},
Params
{
DType
::
kFloat8E4M3
,
DType
::
kFloat8E4M3
,
DType
::
kFloat16
,
true
,
false
,
128
,
128
*
4
,
128
*
4
,
5e-2
}),
&
ParamSuffix
);
tests/jax/distributed_test_base.py
View file @
27ddce40
...
...
@@ -22,7 +22,7 @@ def generate_configs():
pytest
.
param
(
2
,
(
2
,),
(
"dp"
,),
MeshResource
(
dp_resource
=
"dp"
),
id
=
"n2_dp2_tp1"
)
)
configs
.
append
(
pytest
.
param
(
2
,
(
2
,),
(
"tp"
,),
MeshResource
(
tp_resource
=
"tp"
),
id
=
"n2_dp1_tp2"
)
pytest
.
param
(
2
,
(
2
,),
(
"tp
sp
"
,),
MeshResource
(
tp
sp
_resource
=
"tp
sp
"
),
id
=
"n2_dp1_tp2"
)
)
if
is_devices_enough
(
4
):
...
...
@@ -30,8 +30,8 @@ def generate_configs():
pytest
.
param
(
4
,
(
2
,
2
),
(
"dp"
,
"tp"
),
MeshResource
(
dp_resource
=
"dp"
,
tp_resource
=
"tp"
),
(
"dp"
,
"tp
sp
"
),
MeshResource
(
dp_resource
=
"dp"
,
tp
sp
_resource
=
"tp
sp
"
),
id
=
f
"n4_dp2_tp2"
,
)
)
...
...
@@ -43,8 +43,8 @@ def generate_context_parallel_configs_for_attn():
"""Generate CP combinations along with TP+DP for TestDistributedContextParallelSelfAttn only"""
configsL1
=
[]
configsL2
=
[]
mr
=
MeshResource
(
dp_resource
=
"dp"
,
cp_resource
=
"cp"
,
tp_resource
=
"tp"
)
axes
=
(
"dp"
,
"cp"
,
"tp"
)
mr
=
MeshResource
(
dp_resource
=
"dp"
,
cp_resource
=
"cp"
,
tp
sp
_resource
=
"tp
sp
"
)
axes
=
(
"dp"
,
"cp"
,
"tp
sp
"
)
DP_sizes
=
(
1
,
2
)
CP_sizes
=
(
1
,
2
,
4
,
8
)
TP_sizes
=
(
1
,
2
)
...
...
tests/jax/multi_process_launch.sh
0 → 100644
View file @
27ddce40
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
#!/bin/bash
SCRIPT_NAME
=
"
${
SCRIPT_NAME
:-
test
.py
}
"
XLA_BASE_FLAGS
=
"--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_command_buffer=''"
export
XLA_FLAGS
=
"
${
XLA_BASE_FLAGS
}
"
NUM_RUNS
=
$(
nvidia-smi
-L
|
wc
-l
)
for
((
i
=
1
;
i<NUM_RUNS
;
i++
))
do
CUDA_VISIBLE_DEVICES
=
$i
python
$SCRIPT_NAME
127.0.0.1:12345
$i
$NUM_RUNS
>
/dev/null 2>&1 &
done
CUDA_VISIBLE_DEVICES
=
0 python
$SCRIPT_NAME
127.0.0.1:12345 0
$NUM_RUNS
wait
tests/jax/test_custom_call_compute.py
View file @
27ddce40
...
...
@@ -31,6 +31,7 @@ from transformer_engine.jax.cpp_extensions.quantization import (
from
transformer_engine.jax.cpp_extensions.misc
import
get_cudnn_version
from
transformer_engine.jax
import
cpp_extensions
as
tex
from
transformer_engine.jax.quantize
import
(
NoScaleTensor
,
ScaledTensor
,
ScaledTensor1x
,
ScaledTensor2x
,
...
...
@@ -182,7 +183,7 @@ ACTIVATION_TYPES = {
class
TestActivation
:
def
ref_act
(
self
,
x
,
activation_type
):
return
_jax_act_lu
(
x
,
activation_type
)
return
_jax_act_lu
(
x
,
activation_type
)
.
data
def
value_n_grad_ref_func
(
self
,
x
,
activation_type
):
jitted_reference
=
jit
(
...
...
@@ -337,8 +338,8 @@ class TestNorm:
ln_out
,
_
=
_jax_rmsnorm
(
x
,
gamma
,
zero_centered_gamma
,
eps
,
quantizer
)
else
:
ln_out
,
_
,
_
=
_jax_layernorm
(
x
,
gamma
,
beta
,
zero_centered_gamma
,
eps
,
quantizer
)
#
if isinstance(ln_out, ScaledTensor):
#
ln_out = ln_out.dequantize()
#
This is a no-op for non-quantized data
ln_out
=
ln_out
.
dequantize
()
return
ln_out
key
=
jax
.
random
.
PRNGKey
(
0
)
...
...
@@ -464,14 +465,23 @@ class TestNorm:
x
,
gamma
,
beta
,
zero_centered_gamma
,
epsilon
,
quantizer
=
quantizer
)
ref_out
,
ref_mu
,
ref_rsigma
=
_jax_layernorm
(
x
,
gamma
,
beta
,
zero_centered_gamma
,
epsilon
,
quantizer
=
ref_quantizer
x
,
gamma
,
beta
,
zero_centered_gamma
,
epsilon
,
quantizer
=
ref_quantizer
,
)
else
:
output
,
rsigma
=
tex
.
rmsnorm_fwd
(
x
,
gamma
,
zero_centered_gamma
,
epsilon
,
quantizer
=
quantizer
)
ref_out
,
ref_rsigma
=
_jax_rmsnorm
(
x
,
gamma
,
zero_centered_gamma
,
epsilon
,
quantizer
=
ref_quantizer
x
,
gamma
,
zero_centered_gamma
,
epsilon
,
quantizer
=
ref_quantizer
,
)
ref_mu
=
None
...
...
@@ -765,7 +775,9 @@ class TestFusedQuantize:
te_output
,
jax_output
,
precise_comparison
=
precise_comparison
)
else
:
assert_allclose
(
te_output
,
jax_output
)
assert
isinstance
(
te_output
,
NoScaleTensor
)
assert
isinstance
(
jax_output
,
NoScaleTensor
)
assert_allclose
(
te_output
.
data
,
jax_output
.
data
)
if
is_dbias
:
# TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
...
...
@@ -1020,7 +1032,6 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
ln_out
,
_
=
_jax_rmsnorm
(
x
,
gamma
,
zero_centered_gamma
,
eps
,
quantizer
)
else
:
ln_out
,
_
,
_
=
_jax_layernorm
(
x
,
gamma
,
beta
,
zero_centered_gamma
,
eps
,
quantizer
)
if
isinstance
(
ln_out
,
ScaledTensor
):
ln_out
=
ln_out
.
dequantize
()
return
ln_out
...
...
@@ -1177,7 +1188,7 @@ class TestFusedDense:
bias_1_shape
=
(
1
,)
*
(
linear_1_out
.
ndim
-
bias_1
.
ndim
)
+
bias_1
.
shape
linear_1_out
+=
jnp
.
reshape
(
bias_1
,
bias_1_shape
)
x
=
_jax_act_lu
(
linear_1_out
,
activation_type
)
x
=
_jax_act_lu
(
linear_1_out
,
activation_type
)
.
data
linear_2_out
=
jax
.
lax
.
dot_general
(
x
,
kernel_2
,
(((
1
,),
(
0
,)),
((),
())))
if
use_bias
:
bias_2_shape
=
(
1
,)
*
(
linear_2_out
.
ndim
-
bias_2
.
ndim
)
+
bias_2
.
shape
...
...
tests/jax/test_distributed_fused_attn.py
View file @
27ddce40
...
...
@@ -45,8 +45,8 @@ class TestDistributedSelfAttn:
_
,
seqlen
,
heads
,
_
=
shape
is_dp_enabled
=
mesh_resource
.
dp_resource
is
not
None
tp_size
=
1
if
mesh_resource
.
tp_resource
is
not
None
:
idx
=
mesh_axes
.
index
(
mesh_resource
.
tp_resource
)
if
mesh_resource
.
tp
sp
_resource
is
not
None
:
idx
=
mesh_axes
.
index
(
mesh_resource
.
tp
sp
_resource
)
tp_size
=
mesh_shape
[
idx
]
all_reduce_loss_bytes
=
4
# 1 * FP32
...
...
tests/jax/test_distributed_helper.py
0 → 100644
View file @
27ddce40
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import
unittest
import
jax
import
numpy
as
np
from
utils
import
pytest_parametrize_wrapper
,
is_devices_enough
from
transformer_engine.jax.sharding
import
MeshResource
,
global_mesh_resource
from
transformer_engine.jax
import
fp8_autocast
def
generate_mesh_configs
():
configs
=
[]
if
is_devices_enough
(
2
):
configs
.
append
(
[
2
,
(
1
,
2
),
(
"dp"
,
"tpsp"
),
MeshResource
(
dp_resource
=
"dp"
,
tpsp_resource
=
"tpsp"
)]
)
if
is_devices_enough
(
4
):
configs
.
append
(
[
4
,
(
2
,
2
),
(
"fsdp"
,
"tp"
),
MeshResource
(
tp_resource
=
"tp"
,
fsdp_resource
=
"fsdp"
)]
)
return
configs
class
TestMeshResource
(
unittest
.
TestCase
):
def
test_fp8_autocast_with_mesh_resource
(
self
):
for
mesh_config
in
generate_mesh_configs
():
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
=
mesh_config
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
jax
.
sharding
.
Mesh
(
devices
,
mesh_axes
)
with
mesh
,
fp8_autocast
(
enabled
=
False
,
mesh_resource
=
mesh_resource
):
self
.
assertEqual
(
mesh_resource
,
global_mesh_resource
())
tests/jax/test_distributed_layernorm_mlp.py
View file @
27ddce40
...
...
@@ -62,16 +62,16 @@ BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE
=
64
# Only test with FSDP and TP as DP is not used
def
generate_fsdp_and_tp_configs
():
# Only test with FSDP and TP
SP
as DP is not used
def
generate_fsdp_and_tp
sp
_configs
():
configs
=
[]
if
is_devices_enough
(
2
):
configs
.
append
(
[
2
,
(
1
,
2
),
(
"fsdp"
,
"tp"
),
MeshResource
(
fsdp_resource
=
"fsdp"
,
tp_resource
=
"tp"
)]
[
2
,
(
1
,
2
),
(
"fsdp"
,
"tp
sp
"
),
MeshResource
(
fsdp_resource
=
"fsdp"
,
tp
sp
_resource
=
"tp
sp
"
)]
)
if
is_devices_enough
(
4
):
configs
.
append
(
[
4
,
(
2
,
2
),
(
"fsdp"
,
"tp"
),
MeshResource
(
fsdp_resource
=
"fsdp"
,
tp_resource
=
"tp"
)]
[
4
,
(
2
,
2
),
(
"fsdp"
,
"tp
sp
"
),
MeshResource
(
fsdp_resource
=
"fsdp"
,
tp
sp
_resource
=
"tp
sp
"
)]
)
return
configs
...
...
@@ -173,7 +173,9 @@ class TestDistributedLayernormMLP:
)
# Single GPU
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
fp8_recipe
):
with
fp8_autocast
(
enabled
=
fp8_recipe
is
not
None
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
MeshResource
()
):
single_jitter
=
jax
.
jit
(
value_and_grad_func
,
static_argnums
=
range
(
len
(
inputs
),
len
(
static_inputs
)
+
len
(
inputs
)),
...
...
@@ -184,14 +186,14 @@ class TestDistributedLayernormMLP:
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
with
mesh
,
fp8_autocast
(
enabled
=
Tru
e
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
mesh_resource
enabled
=
fp8_recipe
is
not
Non
e
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
mesh_resource
):
k1_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"fsdp"
,
None
,
"tp"
))
k2_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"tp"
,
"fsdp"
))
k1_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"fsdp"
,
None
,
"tp
sp
"
))
k2_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"tp
sp
"
,
"fsdp"
))
k1_
=
jax
.
device_put
(
k1
,
k1_sharding
)
k2_
=
jax
.
device_put
(
k2
,
k2_sharding
)
if
use_bias
:
b1_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
,
"tp"
))
b1_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
,
"tp
sp
"
))
b1_
=
jax
.
device_put
(
b1
,
b1_sharding
)
else
:
b1_sharding
=
b1_
=
None
...
...
@@ -226,7 +228,12 @@ class TestDistributedLayernormMLP:
fwd_test_type
=
dtype
if
fp8_recipe
is
None
else
jnp
.
float8_e4m3fn
bwd_test_type
=
dtype
if
fp8_recipe
is
None
else
jnp
.
float8_e5m2
if
fwd_test_type
==
jnp
.
float16
and
use_bias
:
assert_allclose
(
multi_fwd
,
single_fwd
,
dtype
=
fwd_test_type
,
atol
=
0.04
,
rtol
=
1.5
)
else
:
assert_allclose
(
multi_fwd
,
single_fwd
,
dtype
=
fwd_test_type
)
for
i
in
range
(
len
(
inputs
)):
if
multi_grads
[
i
]
is
not
None
:
if
isinstance
(
multi_grads
[
i
],
list
):
...
...
@@ -247,12 +254,12 @@ class TestDistributedLayernormMLP:
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp
sp
_configs
())
@
pytest_parametrize_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
[
None
]
+
SUPPORTED_RECIPES
)
@
pytest_parametrize_wrapper
(
"with_jax_gemm"
,
[
False
,
True
])
def
test_layernorm_mlp_grad
(
self
,
...
...
@@ -276,12 +283,12 @@ class TestDistributedLayernormMLP:
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp
sp
_configs
())
@
pytest_parametrize_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
[
None
]
+
SUPPORTED_RECIPES
)
@
pytest_parametrize_wrapper
(
"with_jax_gemm"
,
[
False
,
True
])
def
test_layernorm_mlp_grad_shardy
(
self
,
...
...
@@ -330,7 +337,7 @@ class TestDistributedLayernormMLP:
with
use_jax_gemm
(
enabled
=
with_jax_gemm
):
# Single GPUs
with
fp8_autocast
(
enabled
=
use_fp8
,
fp8_recipe
=
fp8_recipe
):
with
fp8_autocast
(
enabled
=
use_fp8
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
MeshResource
()
):
ln_mlp_single
=
LayerNormMLP
(
layernorm_type
=
layernorm_type
,
intermediate_dim
=
INTERMEDIATE
,
...
...
@@ -408,7 +415,7 @@ class TestDistributedLayernormMLP:
assert_allclose
(
mlp_out_sharded
,
mlp_out_single
,
dtype
=
dtype
,
atol
=
atol
,
rtol
=
rtol
)
@
pytest_parametrize_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp
sp
_configs
())
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"silu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
...
...
@@ -429,7 +436,7 @@ class TestDistributedLayernormMLP:
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp
sp
_configs
())
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
...
...
@@ -452,7 +459,7 @@ class TestDistributedLayernormMLP:
)
@
pytest_parametrize_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp
sp
_configs
())
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"silu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
...
...
@@ -473,7 +480,7 @@ class TestDistributedLayernormMLP:
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp
sp
_configs
())
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
...
...
tests/jax/test_distributed_softmax.py
View file @
27ddce40
...
...
@@ -41,11 +41,11 @@ class TestDistributedSoftmax:
if
not
bad_sharding
:
x_pspec
=
PartitionSpec
(
mesh_resource
.
dp_resource
,
mesh_resource
.
tp_resource
,
None
,
None
mesh_resource
.
dp_resource
,
mesh_resource
.
tp
sp
_resource
,
None
,
None
)
else
:
x_pspec
=
PartitionSpec
(
mesh_resource
.
dp_resource
,
None
,
None
,
mesh_resource
.
tp_resource
mesh_resource
.
dp_resource
,
None
,
None
,
mesh_resource
.
tp
sp
_resource
)
if
broadcast_batch_mask
:
...
...
tests/jax/test_fused_attn.py
View file @
27ddce40
...
...
@@ -41,6 +41,7 @@ from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from
transformer_engine_jax
import
(
NVTE_Fused_Attn_Backend
,
get_cudnn_version
,
get_device_compute_capability
,
)
from
distributed_test_base
import
assert_equal_collectives
...
...
@@ -348,6 +349,14 @@ class FusedAttnRunner:
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
if
(
get_device_compute_capability
(
0
)
==
100
and
self
.
dropout_prob
==
0.1
and
self
.
attn_bias_type
is
not
AttnBiasType
.
NO_BIAS
):
pytest
.
skip
(
"For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
)
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
if
self
.
head_dim_qk
!=
self
.
head_dim_v
and
not
self
.
qkv_layout
.
is_separate
():
...
...
@@ -397,7 +406,7 @@ class FusedAttnRunner:
self
.
mesh
=
Mesh
(
self
.
devices
,
self
.
mesh_axes
)
self
.
dp_size
=
self
.
mesh
.
shape
.
get
(
self
.
mesh_resource
.
dp_resource
,
1
)
self
.
cp_size
=
self
.
mesh
.
shape
.
get
(
self
.
mesh_resource
.
cp_resource
,
1
)
self
.
tp_size
=
self
.
mesh
.
shape
.
get
(
self
.
mesh_resource
.
tp_resource
,
1
)
self
.
tp_size
=
self
.
mesh
.
shape
.
get
(
self
.
mesh_resource
.
tp
sp
_resource
,
1
)
key
=
jax
.
random
.
PRNGKey
(
0
)
q_key
,
k_key
,
v_key
,
bias_key
,
dropout_key
=
jax
.
random
.
split
(
key
,
5
)
...
...
@@ -630,7 +639,7 @@ class FusedAttnRunner:
self
.
qkvo_psec
=
PartitionSpec
(
self
.
mesh_resource
.
dp_resource
,
self
.
mesh_resource
.
cp_resource
,
self
.
mesh_resource
.
tp_resource
,
self
.
mesh_resource
.
tp
sp
_resource
,
None
,
)
self
.
qkvo_sharding
=
NamedSharding
(
self
.
mesh
,
self
.
qkvo_psec
)
...
...
@@ -658,7 +667,7 @@ class FusedAttnRunner:
if
self
.
bias_shape
==
BiasShape
.
_1HSS
:
self
.
bias_pspec
=
PartitionSpec
(
None
,
self
.
mesh_resource
.
tp_resource
,
self
.
mesh_resource
.
cp_resource
,
None
None
,
self
.
mesh_resource
.
tp
sp
_resource
,
self
.
mesh_resource
.
cp_resource
,
None
)
elif
self
.
bias_shape
==
BiasShape
.
_B1SS
:
self
.
bias_pspec
=
PartitionSpec
(
...
...
tests/jax/test_helper.py
View file @
27ddce40
...
...
@@ -14,10 +14,11 @@ from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling,
from
transformer_engine.common.recipe
import
Format
as
FP8Format
from
transformer_engine.jax
import
fp8_autocast
,
get_delayed_scaling
from
transformer_engine.jax.quantize
import
(
Q
uantize
C
onfig
,
get_q
uantize
_c
onfig
,
is_fp8_available
,
ScalingMode
,
update_collections
,
TensorSource
,
)
from
transformer_engine.jax.sharding
import
MeshResource
,
global_mesh_resource
...
...
@@ -49,7 +50,7 @@ class TestHelper(unittest.TestCase):
class
TestFP8Functions
(
unittest
.
TestCase
):
def
_check_default_state
(
self
):
self
.
assertFalse
(
Q
uantize
C
onfig
.
is_fp8_enabled
())
self
.
assertFalse
(
get_q
uantize
_c
onfig
()
.
is_fp8_enabled
())
def
_compare_delay_scaling
(
self
,
ref
,
test
):
self
.
assertTrue
(
ref
.
margin
==
test
.
margin
)
...
...
@@ -58,107 +59,90 @@ class TestFP8Functions(unittest.TestCase):
self
.
assertTrue
(
ref
.
amax_compute_algo
==
test
.
amax_compute_algo
)
def
_compare_current_scaling
(
self
,
test
):
self
.
assertEqual
(
QuantizeConfig
.
FP8_FORMAT
,
test
.
fp8_format
)
self
.
assertEqual
(
QuantizeConfig
.
SCALING_MODE
,
ScalingMode
.
CURRENT_TENSOR_SCALING
)
self
.
assertEqual
(
get_quantize_config
().
FP8_FORMAT
,
test
.
fp8_format
)
for
tensor_source
in
TensorSource
:
self
.
assertEqual
(
get_quantize_config
().
get_scaling_mode
(
tensor_source
),
ScalingMode
.
CURRENT_TENSOR_SCALING
,
)
def
_compare_mxfp8_scaling
(
self
,
test
):
self
.
assertEqual
(
QuantizeConfig
.
MARGIN
,
test
.
margin
)
self
.
assertEqual
(
QuantizeConfig
.
FP8_FORMAT
,
test
.
fp8_format
)
self
.
assertEqual
(
QuantizeConfig
.
SCALING_MODE
,
ScalingMode
.
MXFP8_1D_SCALING
)
self
.
assertEqual
(
get_quantize_config
().
MARGIN
,
test
.
margin
)
self
.
assertEqual
(
get_quantize_config
().
FP8_FORMAT
,
test
.
fp8_format
)
for
tensor_source
in
TensorSource
:
self
.
assertEqual
(
get_quantize_config
().
get_scaling_mode
(
tensor_source
),
ScalingMode
.
MXFP8_1D_SCALING
)
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
def
test_fp8_autocast_delayed_scaling
(
self
):
QuantizeConfig
.
finalize
()
# Ensure the testing not affect by previous tests.
self
.
_check_default_state
()
with
fp8_autocast
(
enabled
=
False
,
fp8_recipe
=
DelayedScaling
()):
with
fp8_autocast
(
enabled
=
False
,
fp8_recipe
=
DelayedScaling
()
,
mesh_resource
=
MeshResource
()
):
self
.
_check_default_state
()
self
.
_check_default_state
()
ds
=
DelayedScaling
(
margin
=
5.0
,
fp8_format
=
FP8Format
.
E4M3
,
amax_history_len
=
1
)
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
ds
):
self
.
assertTrue
(
Q
uantize
C
onfig
.
is_fp8_enabled
())
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
ds
,
mesh_resource
=
MeshResource
()
):
self
.
assertTrue
(
get_q
uantize
_c
onfig
()
.
is_fp8_enabled
())
self
.
_compare_delay_scaling
(
get_delayed_scaling
(),
ds
)
self
.
_check_default_state
()
ds
=
DelayedScaling
(
margin
=
3.0
,
fp8_format
=
FP8Format
.
HYBRID
,
amax_history_len
=
1
)
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
ds
):
self
.
assertTrue
(
Q
uantize
C
onfig
.
is_fp8_enabled
())
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
ds
,
mesh_resource
=
MeshResource
()
):
self
.
assertTrue
(
get_q
uantize
_c
onfig
()
.
is_fp8_enabled
())
self
.
_compare_delay_scaling
(
get_delayed_scaling
(),
ds
)
self
.
_check_default_state
()
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
def
test_fp8_autocast_current_scaling
(
self
):
QuantizeConfig
.
finalize
()
# Ensure the testing not affect by previous tests.
self
.
_check_default_state
()
with
fp8_autocast
(
enabled
=
False
,
fp8_recipe
=
Float8CurrentScaling
()):
with
fp8_autocast
(
enabled
=
False
,
fp8_recipe
=
Float8CurrentScaling
(),
mesh_resource
=
MeshResource
()
):
self
.
_check_default_state
()
self
.
_check_default_state
()
cs
=
Float8CurrentScaling
(
fp8_format
=
FP8Format
.
E4M3
)
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
cs
):
self
.
assertTrue
(
Q
uantize
C
onfig
.
is_fp8_enabled
())
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
cs
,
mesh_resource
=
MeshResource
()
):
self
.
assertTrue
(
get_q
uantize
_c
onfig
()
.
is_fp8_enabled
())
self
.
_compare_current_scaling
(
cs
)
self
.
_check_default_state
()
cs
=
Float8CurrentScaling
(
fp8_format
=
FP8Format
.
HYBRID
)
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
cs
):
self
.
assertTrue
(
Q
uantize
C
onfig
.
is_fp8_enabled
())
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
cs
,
mesh_resource
=
MeshResource
()
):
self
.
assertTrue
(
get_q
uantize
_c
onfig
()
.
is_fp8_enabled
())
self
.
_compare_current_scaling
(
cs
)
self
.
_check_default_state
()
@
unittest
.
skipIf
(
not
is_mxfp8_supported
,
reason
=
mxfp8_reason
)
def
test_fp8_autocast_mxfp8_block_scaling
(
self
):
QuantizeConfig
.
finalize
()
# Ensure the testing not affect by previous tests.
self
.
_check_default_state
()
with
fp8_autocast
(
enabled
=
False
,
fp8_recipe
=
MXFP8BlockScaling
()):
with
fp8_autocast
(
enabled
=
False
,
fp8_recipe
=
MXFP8BlockScaling
(),
mesh_resource
=
MeshResource
()
):
self
.
_check_default_state
()
self
.
_check_default_state
()
bs
=
MXFP8BlockScaling
(
margin
=
5.0
,
fp8_format
=
FP8Format
.
E4M3
)
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
bs
):
self
.
assertTrue
(
Q
uantize
C
onfig
.
is_fp8_enabled
())
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
bs
,
mesh_resource
=
MeshResource
()
):
self
.
assertTrue
(
get_q
uantize
_c
onfig
()
.
is_fp8_enabled
())
self
.
_compare_mxfp8_scaling
(
bs
)
self
.
_check_default_state
()
bs
=
MXFP8BlockScaling
(
margin
=
3.0
,
fp8_format
=
FP8Format
.
HYBRID
)
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
bs
):
self
.
assertTrue
(
Q
uantize
C
onfig
.
is_fp8_enabled
())
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
bs
,
mesh_resource
=
MeshResource
()
):
self
.
assertTrue
(
get_q
uantize
_c
onfig
()
.
is_fp8_enabled
())
self
.
_compare_mxfp8_scaling
(
bs
)
self
.
_check_default_state
()
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
def
test_fp8_autocast_with_sharding_resource
(
self
):
QuantizeConfig
.
finalize
()
# Ensure the testing not affect by previous tests.
self
.
_check_default_state
()
ds
=
DelayedScaling
(
margin
=
5.0
,
fp8_format
=
FP8Format
.
E4M3
,
amax_history_len
=
1
)
mesh_s
=
(
(
MeshResource
(
None
,
None
)),
(
MeshResource
(
"dp"
,
None
)),
(
MeshResource
(
None
,
"tp"
)),
(
MeshResource
(
"dp"
,
"tp"
)),
)
# TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
mesh_shape
=
(
1
,
1
)
devices
=
np
.
asarray
(
jax
.
devices
()[:
1
]).
reshape
(
*
mesh_shape
)
with
jax
.
sharding
.
Mesh
(
devices
,
(
"dp"
,
"tp"
)):
for
sr
in
mesh_s
:
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
ds
,
mesh_resource
=
sr
):
self
.
assertTrue
(
QuantizeConfig
.
is_fp8_enabled
())
self
.
_compare_delay_scaling
(
get_delayed_scaling
(),
ds
)
self
.
assertEqual
(
sr
,
global_mesh_resource
())
self
.
_check_default_state
()
tests/jax/test_layer.py
View file @
27ddce40
...
...
@@ -23,11 +23,14 @@ from utils import EncoderLayer as RefEncoderLayer
from
transformer_engine.common
import
recipe
from
transformer_engine.jax.flax
import
TransformerLayer
,
TransformerLayerType
from
transformer_engine.jax.quantize
import
(
Q
uantize
C
onfig
,
get_q
uantize
_c
onfig
,
ScalingMode
,
is_fp8_available
,
update_collections
,
TensorSource
,
fp8_autocast
,
)
from
transformer_engine.jax.sharding
import
MeshResource
@
pytest
.
fixture
(
autouse
=
True
,
scope
=
"function"
)
...
...
@@ -262,6 +265,16 @@ ATTRS = [
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_WINDOW_SIZE
:
(
2
,
2
),
},
# attrs29
{
_KEY_OF_RELATIVE_EMBEDDING
:
True
,
_KEY_OF_SELF_ATTN_BIAS_TYPE
:
"pre_scale_bias"
,
},
# attrs30
{
_KEY_OF_RELATIVE_EMBEDDING
:
True
,
_KEY_OF_SELF_ATTN_BIAS_TYPE
:
"post_scale_bias"
,
},
]
ATTRS
=
[{
**
BASE_ATTRS
,
**
attr
}
for
attr
in
ATTRS
]
...
...
@@ -345,7 +358,7 @@ class BaseRunner:
ref_params
,
test_params
=
self
.
_sync_params
(
ref_params
,
test_params
)
if
Q
uantize
C
onfig
.
is_fp8_enabled
():
if
get_q
uantize
_c
onfig
()
.
is_fp8_enabled
():
for
_
in
range
(
4
):
_
,
updated_state
=
jax
.
value_and_grad
(
self
.
_loss_fn
,
argnums
=
(
3
,),
has_aux
=
False
)(
inputs
,
...
...
@@ -354,12 +367,15 @@ class BaseRunner:
test_others
,
test_layer
,
)
if
QuantizeConfig
.
SCALING_MODE
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
if
(
get_quantize_config
().
get_scaling_mode
(
TensorSource
.
X
)
==
ScalingMode
.
DELAYED_TENSOR_SCALING
):
_
,
updated_quantize_meta
=
flax
.
core
.
pop
(
updated_state
[
0
],
Q
uantize
C
onfig
.
COLLECTION_NAME
updated_state
[
0
],
get_q
uantize
_c
onfig
()
.
COLLECTION_NAME
)
test_others
=
update_collections
(
{
Q
uantize
C
onfig
.
COLLECTION_NAME
:
updated_quantize_meta
},
test_others
{
get_q
uantize
_c
onfig
()
.
COLLECTION_NAME
:
updated_quantize_meta
},
test_others
)
del
updated_quantize_meta
del
updated_state
...
...
@@ -489,29 +505,33 @@ class BaseTester:
def
test_forward
(
self
,
data_shape
,
dtype
,
attrs
):
"""Test normal datatype forward"""
QuantizeConfig
.
finalize
()
# Ensure FP8 disabled.
# Ensure FP8 disabled.
# Empty MeshResource is used as we are running on a single device
with
fp8_autocast
(
enabled
=
False
,
mesh_resource
=
MeshResource
()):
self
.
runner
(
attrs
).
test_forward
(
data_shape
,
dtype
)
def
test_backward
(
self
,
data_shape
,
dtype
,
attrs
):
"""Test normal datatype backward"""
QuantizeConfig
.
finalize
()
# Ensure FP8 disabled.
# Ensure FP8 disabled.
# Empty MeshResource is used as we are running on a single device
with
fp8_autocast
(
enabled
=
False
,
mesh_resource
=
MeshResource
()):
self
.
runner
(
attrs
).
test_backward
(
data_shape
,
dtype
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
QUANTIZE_RECIPES
)
def
test_forward_with_fp8
(
self
,
data_shape
,
dtype
,
attrs
,
fp8_recipe
):
"""Test forward with fp8 enabled"""
QuantizeConfig
.
initialize
(
fp8_recipe
=
fp8_recipe
)
# Empty MeshResource is used as we are running on a single device
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
MeshResource
()):
self
.
runner
(
attrs
).
test_forward
(
data_shape
,
dtype
,
rtol
=
1e-4
,
atol
=
1e-3
)
QuantizeConfig
.
finalize
()
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
QUANTIZE_RECIPES
)
def
test_backward_with_fp8
(
self
,
data_shape
,
dtype
,
attrs
,
fp8_recipe
):
"""Test backward with fp8 enabled"""
QuantizeConfig
.
initialize
(
fp8_recipe
=
fp8_recipe
)
# Empty MeshResource is used as we are running on a single device
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
MeshResource
()):
self
.
runner
(
attrs
).
test_backward
(
data_shape
,
dtype
,
rtol
=
1e-4
,
atol
=
1e-3
)
QuantizeConfig
.
finalize
()
class
TestEncoderLayer
(
BaseTester
):
...
...
tests/jax/test_multi_process_distributed_grouped_gemm.py
0 → 100644
View file @
27ddce40
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from
functools
import
partial
import
jax
import
jax.numpy
as
jnp
import
jax.experimental.multihost_utils
as
jem
from
transformer_engine.jax.dense
import
grouped_dense
as
te_grouped_dense
from
transformer_engine.jax.quantize
import
(
QuantizerFactory
,
ScalingMode
,
)
from
utils
import
assert_allclose
,
dtype_tols
# Number of GEMM groups in every grouped-dense call below.
N_GROUP = 8
# Name of the single mesh axis used for FSDP-style sharding in this test.
MESH_AXIS_NAME = "fsdp"
def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis):
    """Check FSDP-sharded FP8 grouped GEMM against an unsharded reference.

    Runs ``te_grouped_dense`` under ``shard_map`` with the kernel sharded on
    the given FSDP axis (so it must be all-gathered in FP8), then compares the
    forward output and the input/kernel gradients against the same grouped
    GEMM computed with a fully replicated kernel.

    Args:
        data_shapes: ``(x_shape, w_shape)`` shapes for a single group; the
            group dimension ``N_GROUP`` is prepended when data is initialized.
        kernel_fsdp_axis: Kernel axis sharded over the ``"fsdp"`` mesh axis;
            must be 1 or 2.

    Raises:
        AssertionError: If the sharded and reference results disagree beyond
            FP8 tolerances.
    """
    assert kernel_fsdp_axis in [1, 2]
    x_shape, w_shape = data_shapes

    # Activations are sharded along the second (batch-like) axis. The kernel
    # is sharded along the requested FSDP axis for the test path and fully
    # replicated for the reference path.
    x_sharding = NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None, None, None))
    w_sharding = (
        NamedSharding(mesh, PartitionSpec(None, None, MESH_AXIS_NAME))
        if kernel_fsdp_axis == 2
        else NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None))
    )
    w_no_sharding = NamedSharding(mesh, PartitionSpec(None, None, None))

    def init_data():
        # Deterministic inputs. The kernel amax is precomputed per group so
        # the sharded path can quantize without a cross-device reduction.
        x_key = jax.random.PRNGKey(0)
        w_key = jax.random.PRNGKey(1)
        x = jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16)
        w = jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16)
        w_amax = jnp.max(jnp.abs(w), axis=range(1, w.ndim))
        # The kernel is returned twice: one copy is placed with FSDP sharding
        # (test path) and one is kept replicated (reference path).
        return x, w, w, w_amax

    def test_func(outer_x, outer_w, outer_w_amax):
        in_specs = (x_sharding.spec, w_sharding.spec, None)
        out_specs = x_sharding.spec

        @partial(
            shard_map.shard_map,
            mesh=mesh,
            in_specs=in_specs,
            out_specs=out_specs,
            check_rep=False,
        )
        def sharded_group_gemm(x, w, w_amax):
            group_size = x.shape[0]
            # Flatten all leading dims into tokens; every group gets an equal
            # share of the flattened rows.
            x_reshaped = x.reshape(-1, x.shape[-1])
            n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size)
            quantizer_set = QuantizerFactory.create_set(
                scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
                fwd_dtype=jnp.float8_e4m3fn,
                bwd_dtype=jnp.float8_e5m2,
                is_2x2x=True,
                n_groups=group_size,
            )
            # kernel_fsdp_info tells TE which axis of the kernel is sharded so
            # it can all-gather the quantized kernel internally.
            output = te_grouped_dense(
                x_reshaped,
                w,
                n_groups,
                kernel_amax=w_amax,
                quantizer_set=quantizer_set,
                kernel_fsdp_info=(MESH_AXIS_NAME, kernel_fsdp_axis),
            )
            return output.reshape(*x.shape[:-1], -1)

        def run(x, w, w_amax):
            return sharded_group_gemm(x, w, w_amax)

        output, vjp_fn = jax.vjp(run, outer_x, outer_w, outer_w_amax)
        # Use the forward output itself as the cotangent; ignore the amax grad.
        dx, dw, _ = vjp_fn(output)
        return output, dx, dw

    def ref_func(outer_x, outer_w):
        in_specs = (x_sharding.spec, w_no_sharding.spec)
        out_specs = x_sharding.spec

        @partial(
            shard_map.shard_map,
            mesh=mesh,
            in_specs=in_specs,
            out_specs=out_specs,
            check_rep=False,
        )
        def sharded_group_gemm(x, w):
            group_size = x.shape[0]
            x_reshaped = x.reshape(-1, x.shape[-1])
            n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size)
            quantizer_set = QuantizerFactory.create_set(
                scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
                fwd_dtype=jnp.float8_e4m3fn,
                bwd_dtype=jnp.float8_e5m2,
                is_2x2x=True,
                n_groups=group_size,
            )
            # Reference path: replicated kernel, no FSDP all-gather involved.
            output = te_grouped_dense(x_reshaped, w, n_groups, quantizer_set=quantizer_set)
            return output.reshape(*x.shape[:-1], -1)

        def run(x, w):
            return sharded_group_gemm(x, w)

        output, vjp_fn = jax.vjp(run, outer_x, outer_w)
        dx, dw = vjp_fn(output)
        return output, dx, dw

    init_func = jax.jit(init_data, out_shardings=(x_sharding, w_sharding, w_no_sharding, None))
    x, w, w_global, w_amax = init_func()

    o_sharding = x_sharding
    test_func_jitted = jax.jit(
        test_func,
        in_shardings=(x_sharding, w_sharding, None),
        out_shardings=(o_sharding, x_sharding, w_sharding),
    )
    ref_func_jitted = jax.jit(
        ref_func,
        in_shardings=(x_sharding, w_no_sharding),
        out_shardings=(o_sharding, x_sharding, w_no_sharding),
    )

    out, dx, dw = test_func_jitted(x, w, w_amax)
    ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global)

    # Forward runs in e4m3, backward in e5m2 — use matching tolerances.
    e4m3_tols = dtype_tols(jnp.float8_e4m3fn)
    e5m2_tols = dtype_tols(jnp.float8_e5m2)
    out, ref_out = jem.process_allgather((out, ref_out))
    dx, ref_dx = jem.process_allgather((dx, ref_dx))
    dw, ref_dw = jem.process_allgather((dw, ref_dw))
    # Bug fix: the jnp.allclose results were previously discarded, so this
    # test could never fail. Assert on the comparisons instead.
    assert jnp.allclose(out, ref_out, **e4m3_tols), "forward output mismatch"
    assert jnp.allclose(dx, ref_dx, **e5m2_tols), "input gradient (dx) mismatch"
    assert jnp.allclose(dw, ref_dw, **e5m2_tols), "kernel gradient (dw) mismatch"
if __name__ == "__main__":
    # Multi-process entry point. Each participating process launches this
    # script with: <coordinator_address> <process_id> <num_processes>.
    from jax.sharding import NamedSharding, PartitionSpec
    from jax.experimental import shard_map
    import sys

    coord_addr = sys.argv[1]
    proc_id = int(sys.argv[2])
    num_procs = int(sys.argv[3])

    # Join the multi-process JAX runtime before building the mesh.
    jax.distributed.initialize(
        coordinator_address=coord_addr, num_processes=num_procs, process_id=proc_id
    )

    # 1-D mesh of size num_procs with the single "fsdp" axis used by the test.
    mesh = jax.make_mesh((num_procs,), (MESH_AXIS_NAME,))
    with mesh:
        # Each entry is (x_shape, w_shape) for one group.
        data_shapes = [((4, 16, 128, 7168), (7168, 2048))]
        for data_shape in data_shapes:
            # Exercise both supported FSDP-sharded kernel axes.
            for kernel_fsdp_axis in [1, 2]:
                test_grouped_gemm_fp8_allgather(data_shape, kernel_fsdp_axis)
tests/jax/test_sharding.py
deleted
100644 → 0
View file @
d262ef4c
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import
pytest
from
transformer_engine.jax.flax
import
extend_logical_axis_rules
from
transformer_engine.jax.sharding
import
global_shard_guard
,
MeshResource
# Each entry pairs a tuple of (logical_axis, mesh_axis) rules with a flag that
# says whether extend_logical_axis_rules is expected to raise AssertionError
# for that rule set.
LOGICAL_RULES = [
    [(("a1", None), ("a2", "ma2")), False],
    [(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True],
    [(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False],
    [(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True],
    [(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True],
]

# Mesh resources to run every rule set under: no axes, dp only, tp only, both.
MeshS = [
    MeshResource(),
    MeshResource("data", None),
    MeshResource(None, "model"),
    MeshResource("data", "model"),
]


class TestShardingSideAPI:
    """Tests for the logical-axis-rule extension helper of the sharding API."""

    @pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES)
    @pytest.mark.parametrize("sr", MeshS)
    def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
        # Run under the given mesh resource so TE derives its own rules from it.
        with global_shard_guard(sr):
            try:
                # TE's own rules are what extend() appends to any base rules.
                target_te_rules = extend_logical_axis_rules(tuple())
                extended_rules = extend_logical_axis_rules(base_rules)
                # Extended result must be base rules followed by TE's rules.
                assert extended_rules == (*base_rules, *target_te_rules)
                # Reaching here without an exception is a failure for rule
                # sets that are expected to be rejected.
                assert not need_assert
            except AssertionError as ae:
                # Only rule sets flagged with need_assert may raise.
                assert need_assert, f"{ae.args}"
tests/pytorch/attention/test_attention.py
View file @
27ddce40
...
...
@@ -274,6 +274,8 @@ model_configs_mla = {
"mla_3_0"
:
ModelConfig
(
8
,
1
,
16
,
128
,
max_seqlen_kv
=
2048
,
head_dim_v
=
64
),
# inference
"mla_3_1"
:
ModelConfig
(
8
,
1
,
16
,
256
,
max_seqlen_kv
=
2048
,
head_dim_v
=
128
),
# inference
"mla_3_2"
:
ModelConfig
(
8
,
1
,
16
,
192
,
max_seqlen_kv
=
2048
,
head_dim_v
=
128
),
# inference
"mla_3_3"
:
ModelConfig
(
8
,
1
,
16
,
160
,
max_seqlen_kv
=
2048
,
head_dim_v
=
128
),
# inference
"mla_3_4"
:
ModelConfig
(
8
,
1
,
16
,
160
,
max_seqlen_kv
=
2048
,
head_dim_v
=
160
),
# inference
}
...
...
tests/pytorch/attention/test_attention_with_cp.py
View file @
27ddce40
...
...
@@ -37,6 +37,12 @@ model_configs_flash_attn = {
2
,
4096
,
12
,
128
,
num_gqa_groups
=
2
,
attn_mask_type
=
"causal"
,
window_size
=
(
512
,
0
)
),
# GQA
"cp_2_3"
:
ModelConfig
(
2
,
4096
,
12
,
128
,
num_gqa_groups
=
2
,
window_size
=
(
512
,
512
)),
# GQA
"cp_3_0"
:
ModelConfig
(
2
,
4096
,
12
,
192
,
attn_mask_type
=
"causal"
,
head_dim_v
=
128
),
# MLA
"cp_3_1"
:
ModelConfig
(
2
,
4096
,
12
,
192
,
head_dim_v
=
128
),
# MLA
"cp_3_2"
:
ModelConfig
(
2
,
4096
,
12
,
192
,
attn_mask_type
=
"causal"
,
window_size
=
(
512
,
0
),
head_dim_v
=
128
),
# MLA
"cp_3_3"
:
ModelConfig
(
2
,
4096
,
12
,
192
,
window_size
=
(
512
,
512
),
head_dim_v
=
128
),
# MLA
}
...
...
@@ -82,6 +88,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
f
"CP implementation with QKVO A2A requires num_heads (
{
config
.
num_heads
}
) and"
f
" num_gqa_groups (
{
config
.
num_gqa_groups
}
) to be divisible by cp_size (2)!"
)
if
"p2p"
not
in
cp_comm_type
and
config
.
head_dim_qk
!=
config
.
head_dim_v
:
pytest
.
skip
(
"MLA CP currently only support KV P2P!"
)
subprocess
.
run
(
get_bash_arguments
(
...
...
tests/pytorch/attention/test_cp_utils.py
0 → 100644
View file @
27ddce40
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unit tests for context parallel utils."""
import
torch
import
unittest
from
typing
import
Tuple
from
transformer_engine.pytorch.attention.dot_product_attention.context_parallel
import
(
get_batch_on_this_cp_rank
,
pad_thd_sequences_for_cp
,
generate_positional_ids_for_cp
,
)
class TestSequencePadding(unittest.TestCase):
    """Tests for THD-format sequence padding helpers used by context parallel.

    ``pad_thd_sequences_for_cp`` pads each packed sequence up to a multiple of
    ``divisibility_factor`` using caller-supplied padding token/label ids, and
    returns the adjusted cumulative sequence lengths.
    ``generate_positional_ids_for_cp`` emits position ids that keep counting
    through each sequence's padding region (restarting at 0 per sequence).
    """

    def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(self):
        """Test with custom padding values for all tensors."""
        # Setup: three packed sequences of lengths 3, 2, 4 (boundaries in
        # cu_seqlens), all shorter than the divisibility factor of 8.
        input_ids = torch.tensor([1, 1, 1, 2, 2, 3, 3, 3, 3])
        cu_seqlens = torch.tensor([0, 3, 5, 9])
        labels = torch.tensor([-100, -100, -100, -100, -100, -100, -100, 13, -100])
        # NOTE(review): positional_ids is not consumed below; the padded
        # positions are generated from cu_seqlens instead.
        positional_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])
        divisibility_factor = 8
        pid = 777  # custom padding token id
        label_pad = -200  # custom padding label id

        input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
            input_ids.unsqueeze(0),
            labels.unsqueeze(0),
            cu_seqlens,
            divisibility_factor,
            padding_token_id=pid,
            padding_label_id=label_pad,
        )
        positional_ids_padded = generate_positional_ids_for_cp(
            cu_seqlens,
            divisibility_factor,
        )

        # Sequence: [ a a a p p p p p b b pppppp ccccpppp]
        print("input_ids_padded: ", input_ids_padded)
        print("labels_padded: ", labels_padded)
        print("positional_ids_padded: ", positional_ids_padded)
        print("cu_seqlens_padded: ", cu_seqlens_padded)

        # Every sequence is padded to the next multiple of 8.
        expected_input_ids = torch.tensor(
            [
                1, 1, 1, pid, pid, pid, pid, pid,        # seq 1: 3 + 5 padding
                2, 2, pid, pid, pid, pid, pid, pid,      # seq 2: 2 + 6 padding
                3, 3, 3, 3, pid, pid, pid, pid,          # seq 3: 4 + 4 padding
            ]
        )
        expected_cu_seqlens_padded = torch.tensor([0, 8, 16, 24])
        expected_labels_padded = torch.tensor(
            [
                -100, -100, -100, label_pad, label_pad, label_pad, label_pad, label_pad,
                -100, -100, label_pad, label_pad, label_pad, label_pad, label_pad, label_pad,
                -100, -100, 13, -100, label_pad, label_pad, label_pad, label_pad,
            ]
        )
        # Position ids restart per sequence and run through the padding.
        expected_positional_ids = torch.tensor(
            [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
        )

        assert torch.equal(input_ids_padded, expected_input_ids)
        assert torch.equal(labels_padded, expected_labels_padded)
        assert torch.equal(positional_ids_padded, expected_positional_ids)
        assert torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)

    def test_mixed_sequence_lengths_with_divisibility_factor(self):
        """Test with sequences both shorter and longer than divisibility factor."""
        # Setup - divisibility factor 6
        # Seq 1: length 2 (shorter than 6, needs 4 padding)
        # Seq 2: length 7 (longer than 6, needs 5 padding to reach 12)
        # Seq 3: length 4 (shorter than 6, needs 2 padding)
        # Seq 4: length 10 (longer than 6, needs 2 padding to reach 12)
        input_ids = torch.tensor(
            [1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
        )
        labels = torch.tensor(
            [
                10, 11,
                20, 21, 22, 23, 24, 25, 26,
                30, 31, 32, 33,
                40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
            ]
        )
        # NOTE(review): positional_ids is not consumed below; padded positions
        # are generated from cu_seqlens instead.
        positional_ids = torch.tensor(
            [0, 1, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        )
        cu_seqlens = torch.tensor([0, 2, 9, 13, 23])
        divisibility_factor = 6
        pid = 999  # custom padding token id
        label_pad = -300  # custom padding label id

        # Execute
        input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
            input_ids.unsqueeze(0),
            labels.unsqueeze(0),
            cu_seqlens,
            divisibility_factor,
            padding_token_id=pid,
            padding_label_id=label_pad,
        )
        positional_ids_padded = generate_positional_ids_for_cp(
            cu_seqlens,
            divisibility_factor,
        )

        # Assert
        # Seq 1: [1,1] + 4 pads = 6 total
        # Seq 2: [2,2,2,2,2,2,2] + 5 pads = 12 total
        # Seq 3: [3,3,3,3] + 2 pads = 6 total
        # Seq 4: [4,4,4,4,4,4,4,4,4,4] + 2 pads = 12 total
        expected_input_ids = torch.tensor(
            [
                1, 1, pid, pid, pid, pid,  # Seq 1: 2 + 4 padding
                2, 2, 2, 2, 2, 2, 2, pid, pid, pid, pid, pid,  # Seq 2: 7 + 5 padding
                3, 3, 3, 3, pid, pid,  # Seq 3: 4 + 2 padding
                4, 4, 4, 4, 4, 4, 4, 4, 4, 4, pid, pid,  # Seq 4: 10 + 2 padding
            ]
        )
        expected_labels = torch.tensor(
            [
                10, 11, label_pad, label_pad, label_pad, label_pad,
                20, 21, 22, 23, 24, 25, 26, label_pad, label_pad, label_pad, label_pad, label_pad,
                30, 31, 32, 33, label_pad, label_pad,
                40, 41, 42, 43, 44, 45, 46, 47, 48, 49, label_pad, label_pad,
            ]
        )
        expected_positional_ids = torch.tensor(
            [
                0, 1, 2, 3, 4, 5,  # Seq 1 positions continue through padding
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,  # Seq 2 positions continue
                0, 1, 2, 3, 4, 5,  # Seq 3 positions continue
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,  # Seq 4 positions continue
            ]
        )
        expected_cu_seqlens_padded = torch.tensor([0, 6, 18, 24, 36])

        self.assertTrue(torch.equal(input_ids_padded, expected_input_ids))
        self.assertTrue(torch.equal(labels_padded, expected_labels))
        self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids))
        self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded))

    def test_sequences_longer_than_divisibility_factor(self):
        """Test with all sequences longer than the divisibility factor."""
        # Setup - divisibility factor 4, all sequences longer than 4
        # Seq 1: length 7 (needs 1 padding to reach 8)
        # Seq 2: length 11 (needs 1 padding to reach 12)
        # Seq 3: length 5 (needs 3 padding to reach 8)
        input_ids = torch.tensor(
            [
                1, 1, 1, 1, 1, 1, 1,  # 7 tokens
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,  # 11 tokens
                3, 3, 3, 3, 3,  # 5 tokens
            ]
        )
        labels = torch.tensor(
            [
                100, 101, 102, 103, 104, 105, 106,
                200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
                300, 301, 302, 303, 304,
            ]
        )
        # NOTE(review): positional_ids is not consumed below; padded positions
        # are generated from cu_seqlens instead.
        positional_ids = torch.tensor(
            [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 2, 3, 4]
        )
        cu_seqlens = torch.tensor([0, 7, 18, 23])
        divisibility_factor = 4
        pid = 888  # custom padding token id
        label_pad = -400  # custom padding label id

        # Execute
        input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
            input_ids.unsqueeze(0),
            labels.unsqueeze(0),
            cu_seqlens,
            divisibility_factor,
            padding_token_id=pid,
            padding_label_id=label_pad,
        )
        positional_ids_padded = generate_positional_ids_for_cp(
            cu_seqlens,
            divisibility_factor,
        )

        # Assert
        # Seq 1: 7 + 1 pad = 8 (divisible by 4)
        # Seq 2: 11 + 1 pad = 12 (divisible by 4)
        # Seq 3: 5 + 3 pads = 8 (divisible by 4)
        expected_input_ids = torch.tensor(
            [
                1, 1, 1, 1, 1, 1, 1, pid,
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, pid,
                3, 3, 3, 3, 3, pid, pid, pid,
            ]
        )
        expected_labels = torch.tensor(
            [
                100, 101, 102, 103, 104, 105, 106, label_pad,
                200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, label_pad,
                300, 301, 302, 303, 304, label_pad, label_pad, label_pad,
            ]
        )
        expected_positional_ids = torch.tensor(
            [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7]
        )
        expected_cu_seqlens_padded = torch.tensor([0, 8, 20, 28])

        self.assertTrue(torch.equal(input_ids_padded, expected_input_ids))
        self.assertTrue(torch.equal(labels_padded, expected_labels))
        self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids))
        self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded))
class TestContextParallelUtils(unittest.TestCase):
    """Test utilities for context parallel functionality."""

    def setUp(self):
        """Set up mock distributed environment."""
        # Mock torch.distributed functions
        self.original_get_world_size = torch.distributed.get_world_size
        self.original_get_rank = torch.distributed.get_rank

    def tearDown(self):
        """Restore original torch.distributed functions."""
        torch.distributed.get_world_size = self.original_get_world_size
        torch.distributed.get_rank = self.original_get_rank

    def _mock_distributed_env(self, cp_size, cp_rank):
        """Mock the distributed environment for testing."""

        def mock_get_world_size(group=None):
            return cp_size

        def mock_get_rank(group=None):
            return cp_rank

        torch.distributed.get_world_size = mock_get_world_size
        torch.distributed.get_rank = mock_get_rank

    def test_cp_rank_slicing_simple_case(self):
        """Test CP rank slicing with a simple 2-rank, single sequence case."""
        # Setup: Single sequence of length 8, CP size = 2
        # Each sequence gets divided into 2*cp_size = 4 slices of size 2 each
        # Rank 0 gets slices [0,1] and [6,7] (first and last)
        # Rank 1 gets slices [2,3] and [4,5] (second and second-to-last)
        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])  # Shape: (1, 8) - batch first
        labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]])
        position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])  # Shape: (8,) - 1D as expected
        cu_seqlens = torch.tensor([0, 8])

        # Test rank 0
        self._mock_distributed_env(cp_size=2, cp_rank=0)
        input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
            cu_seqlens, input_ids, labels, position_ids
        )
        # Rank 0 should get indices [0,1] and [6,7]
        expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8]])
        expected_labels_r0 = torch.tensor([[10, 20, 70, 80]])
        expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7])
        self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
        self.assertTrue(torch.equal(labels_r0, expected_labels_r0))
        self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0))

        # Test rank 1
        self._mock_distributed_env(cp_size=2, cp_rank=1)
        input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank(
            cu_seqlens, input_ids, labels, position_ids
        )
        # Rank 1 should get indices [2,3] and [4,5]
        expected_input_ids_r1 = torch.tensor([[3, 4, 5, 6]])
        expected_labels_r1 = torch.tensor([[30, 40, 50, 60]])
        expected_pos_ids_r1 = torch.tensor([2, 3, 4, 5])
        self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1))
        self.assertTrue(torch.equal(labels_r1, expected_labels_r1))
        self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1))

    def test_cp_rank_slicing_multiple_sequences(self):
        """Test CP rank slicing with multiple sequences."""
        # Setup: Two sequences of length 8 each, CP size = 2
        # Total sequence length = 16, cu_seqlens = [0, 8, 16]
        input_ids = torch.tensor(
            [[1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18]]
        )
        labels = torch.tensor(
            [[10, 20, 30, 40, 50, 60, 70, 80, 110, 120, 130, 140, 150, 160, 170, 180]]
        )
        position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7])
        cu_seqlens = torch.tensor([0, 8, 16])

        # Test rank 0
        self._mock_distributed_env(cp_size=2, cp_rank=0)
        input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
            cu_seqlens, input_ids, labels, position_ids
        )
        # For each sequence, rank 0 gets first and last slices
        # Seq 1: indices [0,1] and [6,7] -> values [1,2] and [7,8]
        # Seq 2: indices [8,9] and [14,15] -> values [11,12] and [17,18]
        expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8, 11, 12, 17, 18]])
        expected_labels_r0 = torch.tensor([[10, 20, 70, 80, 110, 120, 170, 180]])
        expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7, 0, 1, 6, 7])
        self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
        self.assertTrue(torch.equal(labels_r0, expected_labels_r0))
        self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0))

    def test_cp_rank_slicing_with_cp_size_1(self):
        """Test that CP size = 1 returns original tensors unchanged."""
        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
        labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]])
        position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
        cu_seqlens = torch.tensor([0, 8])

        self._mock_distributed_env(cp_size=1, cp_rank=0)
        input_ids_result, labels_result, pos_ids_result = get_batch_on_this_cp_rank(
            cu_seqlens, input_ids, labels, position_ids
        )

        # With CP size = 1, should return original tensors
        self.assertTrue(torch.equal(input_ids_result, input_ids))
        self.assertTrue(torch.equal(labels_result, labels))
        self.assertTrue(torch.equal(pos_ids_result, position_ids))

    def test_cp_rank_slicing_sequence_dim_detection(self):
        """Test that the function correctly detects sequence dimension."""
        # Test with sequence dimension = 0 (sequence_length, batch_size)
        input_ids = torch.tensor(
            [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]]
        )  # (8, 2)
        labels = torch.tensor(
            [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]]
        )
        position_ids = torch.tensor(
            [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]
        )
        cu_seqlens = torch.tensor([0, 8])

        self._mock_distributed_env(cp_size=2, cp_rank=0)
        input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
            cu_seqlens, input_ids, labels, position_ids
        )

        # Should get indices [0,1] and [6,7] along dimension 0
        expected_input_ids_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]])
        expected_labels_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]])
        expected_pos_ids_r0 = torch.tensor([[0, 0], [1, 1], [6, 6], [7, 7]])
        self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
        self.assertTrue(torch.equal(labels_r0, expected_labels_r0))
        self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0))

    def test_cp_rank_slicing_mixed_dimensions(self):
        """Test CP rank slicing where input_ids/labels are 1D but position_ids has batch dimension."""
        # Setup: Single sequence of length 8, CP size = 2
        # This tests the opposite case from the simple test:
        # - input_ids and labels: 1D (no batch dimension)
        # - position_ids: 2D (has batch dimension)
        input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])  # Shape: (8,) - 1D
        labels = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80])  # Shape: (8,) - 1D
        position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])  # Shape: (1, 8) - 2D with batch
        cu_seqlens = torch.tensor([0, 8])

        # Test rank 0
        self._mock_distributed_env(cp_size=2, cp_rank=0)
        input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
            cu_seqlens, input_ids, labels, position_ids
        )
        # Rank 0 should get indices [0,1] and [6,7]
        expected_input_ids_r0 = torch.tensor([1, 2, 7, 8])  # 1D result
        expected_labels_r0 = torch.tensor([10, 20, 70, 80])  # 1D result
        expected_pos_ids_r0 = torch.tensor([[0, 1, 6, 7]])  # 2D result (preserves batch dim)
        self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
        self.assertTrue(torch.equal(labels_r0, expected_labels_r0))
        self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0))

        # Test rank 1
        self._mock_distributed_env(cp_size=2, cp_rank=1)
        input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank(
            cu_seqlens, input_ids, labels, position_ids
        )
        # Rank 1 should get indices [2,3] and [4,5]
        expected_input_ids_r1 = torch.tensor([3, 4, 5, 6])  # 1D result
        expected_labels_r1 = torch.tensor([30, 40, 50, 60])  # 1D result
        expected_pos_ids_r1 = torch.tensor([[2, 3, 4, 5]])  # 2D result (preserves batch dim)
        self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1))
        self.assertTrue(torch.equal(labels_r1, expected_labels_r1))
        self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1))

    def test_integration_with_padding_and_cp_slicing(self):
        """Integration test: pad sequences then slice for CP ranks."""
        # Start with unpadded sequences
        input_ids = torch.tensor([1, 1, 2, 2, 2])  # Two sequences: [1,1] and [2,2,2]
        labels = torch.tensor([10, 11, 20, 21, 22])
        positional_ids = torch.tensor([0, 1, 0, 1, 2])
        cu_seqlens = torch.tensor([0, 2, 5])
        divisibility_factor = 4  # Will pad to lengths 4 and 4

        # First, pad sequences
        input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
            input_ids.unsqueeze(0),
            labels.unsqueeze(0),
            cu_seqlens,
            divisibility_factor,
            padding_token_id=0,
            padding_label_id=-100,
        )
        positional_ids_padded = generate_positional_ids_for_cp(
            cu_seqlens,
            divisibility_factor,
        )

        # Expected after padding: [1,1,0,0,2,2,2,0] with cu_seqlens [0,4,8]
        expected_padded = torch.tensor([1, 1, 0, 0, 2, 2, 2, 0])
        self.assertTrue(torch.equal(input_ids_padded, expected_padded))

        # Now test CP slicing with cp_size=2
        # Test rank 0
        self._mock_distributed_env(cp_size=2, cp_rank=0)
        input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
            cu_seqlens_padded,
            input_ids_padded.unsqueeze(0),
            labels_padded.unsqueeze(0),
            positional_ids_padded,
        )
        # Each sequence of length 4 gets divided into 4 slices of size 1
        # Rank 0 gets slices [0] and [3] from each sequence
        # Seq 1: indices [0] and [3] -> values [1] and [0]
        # Seq 2: indices [4] and [7] -> values [2] and [0]
        expected_input_ids_r0 = torch.tensor([[1, 0, 2, 0]])
        self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
if __name__ == "__main__":
    # Run the test suite when executed as a script.
    unittest.main()
tests/pytorch/debug/test_api_features.py
View file @
27ddce40
...
...
@@ -268,7 +268,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
)[
0
]
expected_underflows
=
(
((
tensor_fp8
.
_data
==
0
).
sum
()
-
(
tensor
==
0
).
sum
())
*
100
/
(
100
*
100
*
5
)
((
tensor_fp8
.
dequantize
()
==
0
).
sum
()
-
(
tensor
==
0
).
sum
())
*
100
/
(
100
*
100
*
5
)
)
assert
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
...
...
@@ -302,7 +302,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
)[
0
]
# Second config in same yaml
tensor
=
torch
.
rand
((
100
,
100
,
5
))
tensor
=
torch
.
rand
((
100
,
100
,
5
))
.
cuda
()
debug_api
.
transformer_engine
.
inspect_tensor
(
"decoder.6.mlp.fc1"
,
tensor_name
=
"activation"
,
...
...
@@ -316,7 +316,9 @@ def test_statistics_collection(configs_dir, feature_dirs):
stats
=
log
()
stats_names
=
[
x
[
3
]
for
x
in
stats
.
keys
()]
all
(
s
in
stats_names
for
s
in
[
"cur_amax"
,
"dynamic_range"
,
"mean"
,
"std"
,
"l1_norm"
])
assert
stats
[(
"decoder.6.mlp.fc1"
,
"activation"
,
"mean"
,
200
)]
==
tensor
.
mean
()
torch
.
testing
.
assert_close
(
stats
[(
"decoder.6.mlp.fc1"
,
"activation"
,
"mean"
,
200
)],
tensor
.
mean
()
)
debug_api
.
transformer_engine
.
inspect_tensor
(
"decoder.7.mlp.fc1"
,
...
...
@@ -331,7 +333,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
stats
=
log
()
stats_names
=
[
x
[
3
]
for
x
in
stats
.
keys
()]
all
(
s
in
stats_names
for
s
in
[
"mean"
,
"std"
,
"l1_norm"
,
"min"
,
"max"
])
assert
stats
[(
"decoder.7.mlp.fc1"
,
"weight"
,
"max"
,
200
)]
==
tensor
.
max
()
torch
.
testing
.
assert_close
(
stats
[(
"decoder.7.mlp.fc1"
,
"weight"
,
"max"
,
200
)]
,
tensor
.
max
()
)
assert
not
debug_api
.
transformer_engine
.
inspect_tensor_enabled
(
"decoder.7.mlp.fc1"
,
tensor_name
=
"weight"
,
iteration
=
201
...
...
@@ -377,7 +379,7 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
return
quantizer
(
t
.
cuda
())
shape
=
[
1024
,
1024
]
tensors
=
[
torch
.
randn
(
shape
)
for
_
in
range
(
2
)]
tensors
=
[
torch
.
randn
(
shape
)
.
cuda
()
for
_
in
range
(
2
)]
tensors_fp8
=
[
fp8_tensor
(
tensors
[
i
])
for
i
in
range
(
2
)]
feed
(
tensors
[
0
],
tensors_fp8
[
0
],
quantizer
)
...
...
tests/pytorch/debug/test_log.py
View file @
27ddce40
...
...
@@ -119,6 +119,9 @@ def read_log(log_dir: str) -> str:
def
test_sanity
(
feature_dirs
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
log_all_stats_config
=
LOG_QUANTIZED_CONFIG_BASE
.
format
(
stats
=
", "
.
join
(
all_stats
))
with
debug_session
(
log_all_stats_config
,
feature_dirs
)
as
log_dir
:
model
=
te
.
Linear
(
128
,
128
,
params_dtype
=
torch
.
bfloat16
)
...
...
@@ -164,8 +167,8 @@ def test_numerics(fp8_recipe, feature_dirs):
num_quantizers
=
3
,
)
tensor
=
torch
.
zeros
(
1024
,
1024
).
cuda
()
tensor
[
0
,
:
]
=
100
0
tensor
=
torch
.
randn
(
1024
,
1024
).
cuda
()
tensor
[
0
,
100
:
200
]
=
-
0.
0
quantizer
=
recipe_state
.
make_quantizers
()[
0
]
quantized_tensor
=
quantizer
(
tensor
)
...
...
@@ -188,15 +191,13 @@ def test_numerics(fp8_recipe, feature_dirs):
if
"underflows%"
in
line
:
underflows
=
float
(
line
.
split
(
"value="
)[
1
])
expected
=
(
((
dequantized_tensor
==
0
).
sum
()
-
(
tensor
==
0
).
sum
())
/
dequantized_tensor
.
numel
()
*
100
((
dequantized_tensor
==
0
).
sum
()
-
(
tensor
==
0
).
sum
())
/
tensor
.
numel
()
*
100
)
assert
underflows
==
pytest
.
approx
(
expected
.
cpu
(),
abs
=
1e-4
)
if
"mse"
in
line
:
mse
=
float
(
line
.
split
(
"value="
)[
1
])
expected
=
torch
.
nn
.
functional
.
mse_loss
(
dequantized_tensor
,
tensor
,
reduction
=
"mean"
)
assert
mse
==
pytest
.
approx
(
expected
.
cpu
(),
abs
=
1e-
6
)
assert
mse
==
pytest
.
approx
(
expected
.
cpu
(),
abs
=
1e-
4
)
if
"overflows%"
in
line
:
overflows
=
float
(
line
.
split
(
"value="
)[
1
])
expected
=
(
...
...
@@ -207,6 +208,9 @@ def test_numerics(fp8_recipe, feature_dirs):
@
pytest
.
mark
.
parametrize
(
"layer"
,
[
"linear"
,
"transformer"
])
def
test_log_every_3_or_5_layers
(
layer
,
configs_dir
,
feature_dirs
):
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
# If layer does not invoke any feature in current iteration,
# then it changed into non-debug mode.
# This test checks whether this works correctly -
...
...
Prev
1
2
3
4
5
6
7
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment