Commit 27ddce40 authored by wenjh

Merge branch 'nv_main'

parents d262ef4c 5b3092a0
......@@ -891,9 +891,18 @@ void fillCase(Tensor *t, const InputsFillCase fill_case) {
}
}
template void fillCase<byte>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int32>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int64>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp32>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<bf16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e4m3>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e5m2>(Tensor *t, const InputsFillCase fill_case);
#if FP4_TYPE_SUPPORTED
template void fillCase<fp4e2m1>(Tensor *t, const InputsFillCase fill_case);
#endif
void setRandomScale(Tensor *t) {
std::uniform_real_distribution<> dis(-2.0, 1.0);
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
cmake_minimum_required(VERSION 3.18)
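# Default CUDA architectures: the Blackwell targets (sm_100, sm_120) require CUDA 12.8+,
# so they are only included when the toolkit is new enough.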
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90)
endif()
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
project(transformer_engine_distributed_tests LANGUAGES CUDA CXX)
add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest)
include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
if(NOT DEFINED TE_LIB_PATH)
execute_process(COMMAND bash -c "python3 -c 'import transformer_engine as te; print(te.__file__)'"
OUTPUT_VARIABLE TE_LIB_FILE
OUTPUT_STRIP_TRAILING_WHITESPACE)
get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY)
endif()
find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED)
message(STATUS "Found transformer_engine library: ${TE_LIB}")
include_directories(../../transformer_engine/common/include)
include_directories(../../transformer_engine/common)
include_directories(../../transformer_engine)
include_directories(${CMAKE_SOURCE_DIR})
find_package(CUDAToolkit REQUIRED)
add_executable(test_comm_gemm
test_comm_gemm.cu
../cpp/test_common.cu)
find_package(OpenMP REQUIRED)
find_package(MPI REQUIRED)
find_library(NCCL_LIB
NAMES nccl libnccl
PATH_SUFFIXES lib
REQUIRED)
target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include)
target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX)
include(GoogleTest)
gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600)
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <gtest/gtest.h>
#include <mpi.h>
#include <nccl.h>
#include <transformer_engine/comm_gemm.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <iostream>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <vector>
#include "../cpp/test_common.h"
#include "common.h"
using transformer_engine::DType;
using transformer_engine::TypeInfo;
#define CHECK_MPI(expr) \
do { \
int err = (expr); \
if (err != MPI_SUCCESS) { \
char err_str[MPI_MAX_ERROR_STRING + 1]{}; \
int _len{}; \
MPI_Error_string(err, err_str, &_len); \
EXPECT_TRUE(false) << "MPI error: " << err << ": " << err_str; \
} \
} while (false)
#define CHECK_NCCL(expr) \
do { \
ncclResult_t err = (expr); \
if (err != ncclSuccess) { \
EXPECT_TRUE(false) << "NCCL error: " << err << ": " << ncclGetErrorString(err); \
} \
} while (false)
#define CHECK_CU(expr) \
do { \
CUresult err = (expr); \
if (err != CUDA_SUCCESS) { \
const char* str{}; \
CUresult e_str = cuGetErrorString(err, &str); \
if (e_str != CUDA_SUCCESS) str = "(unknown)"; \
EXPECT_TRUE(false) << "CU error: " << err << ": " << str; \
} \
} while (false)
int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
CHECK_MPI(MPI_Init(&argc, &argv));
auto ret = RUN_ALL_TESTS();
CHECK_MPI(MPI_Finalize());
return ret;
}
bool IsMulticastSupported(int device_id) {
int supported = 0;
CHECK_CU(cuDeviceGetAttribute(&supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, device_id));
return supported;
}
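// CopyMatrix extracts the msize x nsize sub-block starting at (mstart, nstart) from a
// column-major matrix with leading dimension ld into a dense column-major buffer.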
template <typename T>
std::vector<T> CopyMatrix(const std::vector<T>& data, size_t mstart, size_t nstart, size_t msize,
size_t nsize, size_t ld) {
std::vector<T> ret(msize * nsize);
size_t dst = 0;
for (size_t j = nstart; j < nstart + nsize; ++j) {
for (size_t i = mstart; i < mstart + msize; ++i) {
ret[dst++] = data[j * ld + i];
}
}
return ret;
}
template <typename T>
test::Tensor Make(size_t m, size_t n, float scale) {
test::Tensor ret("", std::vector{n, m}, TypeInfo<T>::dtype);
ret.set_scale(scale);
ret.set_scale_inv(1.0 / scale);
return ret;
}
template <typename T>
test::Tensor MakeFromData(const std::vector<T>& data, size_t mstart, size_t nstart, size_t msize,
size_t nsize, size_t ld, float scale) {
test::Tensor ret("", std::vector{nsize, msize}, TypeInfo<T>::dtype);
ret.set_scale(scale);
ret.set_scale_inv(1.0 / scale);
auto local = CopyMatrix(data, mstart, nstart, msize, nsize, ld);
NVTE_CHECK_CUDA(cudaMemcpy(ret.rowwise_dptr(), local.data(), local.size() * sizeof local[0],
cudaMemcpyDefault));
return ret;
}
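// GetScale picks a quantization scale mapping values up to amax onto the representable
// range of T. Types wider than one byte need no scaling; for FP8 types the saturating
// cast of float-max to T yields T's largest finite value, giving max(T) / amax.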
template <typename T>
float GetScale(float amax) {
if constexpr (sizeof(T) > 1) return 1.0;
return static_cast<float>(static_cast<T>(std::numeric_limits<float>::max())) / amax;
}
struct Params {
DType a_type;
DType b_type;
DType d_type;
bool transa;
bool transb;
size_t m;
size_t n;
size_t k;
float tol;
};
class CommGemmFixture : public ::testing::TestWithParam<Params> {
protected:
CommGemmFixture() {
CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &nranks_));
CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &rank_));
NVTE_CHECK_CUDA(cudaSetDevice(rank_));
ncclUniqueId id{};
if (rank_ == 0) CHECK_NCCL(ncclGetUniqueId(&id));
CHECK_MPI(MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
CHECK_NCCL(ncclCommInitRank(&comm_, nranks_, id, rank_));
ctx_ = nvte_comm_gemm_ctx_create(comm_, nranks_, rank_);
}
~CommGemmFixture() {
nvte_comm_gemm_ctx_destroy(ctx_);
ncclCommDestroy(comm_);
}
struct PatternDims {
int64_t a_rows_start;
int64_t a_rows_num;
int64_t a_cols_start;
int64_t a_cols_num;
int64_t b_rows_start;
int64_t b_rows_num;
int64_t b_cols_start;
int64_t b_cols_num;
int64_t d_rows_start;
int64_t d_rows_num;
int64_t d_cols_start;
int64_t d_cols_num;
};
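// Each derived fixture fills PatternDims with the sub-blocks of the global A, B, and D
// matrices that this rank owns under its communication pattern.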
virtual PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) = 0;
virtual void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b,
const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out,
bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t stream) = 0;
template <typename AType, typename BType, typename DType, typename BiasType>
void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) {
cudaStream_t stream{};
NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
constexpr float MAX_IN = 1.0;
std::mt19937 rng(12);
std::uniform_real_distribution<float> dist(0.0, MAX_IN);
float a_scale = GetScale<AType>(MAX_IN);
float b_scale = GetScale<BType>(MAX_IN);
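// A k-term dot product of entries bounded by MAX_IN is bounded by k * MAX_IN * MAX_IN,
// which is the amax used to choose the output scale.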
float d_scale = GetScale<DType>(MAX_IN * MAX_IN * k);
float bias_scale = GetScale<BiasType>(MAX_IN);
std::vector<AType> adata(m * k);
std::generate(adata.begin(), adata.end(),
[&rng, &dist, a_scale] { return static_cast<AType>(dist(rng) * a_scale); });
std::vector<BType> bdata(k * n);
std::generate(bdata.begin(), bdata.end(),
[&rng, &dist, b_scale] { return static_cast<BType>(dist(rng) * b_scale); });
std::vector<BiasType> biasdata(m * n);
std::generate(biasdata.begin(), biasdata.end(), [&rng, &dist, bias_scale] {
return static_cast<BiasType>(dist(rng) * bias_scale);
});
auto ga = transa ? MakeFromData<AType>(adata, 0, 0, k, m, k, a_scale)
: MakeFromData<AType>(adata, 0, 0, m, k, m, a_scale);
auto gb = transb ? MakeFromData<BType>(bdata, 0, 0, n, k, n, b_scale)
: MakeFromData<BType>(bdata, 0, 0, k, n, k, b_scale);
auto gbias = MakeFromData<BiasType>(biasdata, 0, 0, m, n, m, bias_scale);
auto gd = Make<DType>(m, n, d_scale);
auto gaux = Make<DType>(m, n, d_scale);
auto dims = DistributeTensors(m, n, k);
auto a = transa ? MakeFromData<AType>(adata, dims.a_rows_start, dims.a_cols_start,
dims.a_rows_num, dims.a_cols_num, k, a_scale)
: MakeFromData<AType>(adata, dims.a_cols_start, dims.a_rows_start,
dims.a_cols_num, dims.a_rows_num, m, a_scale);
auto b = transb ? MakeFromData<BType>(bdata, dims.b_cols_start, dims.b_rows_start,
dims.b_cols_num, dims.b_rows_num, n, b_scale)
: MakeFromData<BType>(bdata, dims.b_rows_start, dims.b_cols_start,
dims.b_rows_num, dims.b_cols_num, k, b_scale);
auto bias = MakeFromData<BiasType>(biasdata, dims.d_rows_start, dims.d_cols_start,
dims.d_rows_num, dims.d_cols_num, m, bias_scale);
auto d = Make<DType>(dims.d_rows_num, dims.d_cols_num, d_scale);
auto aux = Make<DType>(dims.d_rows_num, dims.d_cols_num, d_scale);
bool grad = false;
bool accumulate = false;
CommGemm(m, n, k, a.data(), b.data(), d.data(), bias.data(), aux.data(), transa, transb, grad,
accumulate, 0 /*comm_sm_count*/, stream);
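// 32 MiB cuBLAS workspace for the single-GPU reference GEMM.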
auto workspace = Make<uint8_t>(1, 32 << 20, 1.0);
nvte_cublas_gemm(ga.data(), gb.data(), gd.data(), gbias.data(), gaux.data(), transa, transb,
grad, workspace.data(), accumulate, false /* use_split_accumulator */,
0 /* math_sm_count */, stream);
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
NVTE_CHECK_CUDA(cudaStreamDestroy(stream));
std::vector<DType> out(dims.d_rows_num * dims.d_cols_num);
NVTE_CHECK_CUDA(
cudaMemcpy(out.data(), d.rowwise_dptr(), out.size() * sizeof out[0], cudaMemcpyDefault));
std::vector<DType> out_golden_global(m * n);
NVTE_CHECK_CUDA(cudaMemcpy(out_golden_global.data(), gd.rowwise_dptr(),
out_golden_global.size() * sizeof out_golden_global[0],
cudaMemcpyDefault));
auto out_golden = CopyMatrix(out_golden_global, dims.d_rows_start, dims.d_cols_start,
dims.d_rows_num, dims.d_cols_num, m);
NVTE_CHECK(out.size() == out_golden.size());
for (size_t i = 0; i < out.size(); ++i) {
EXPECT_NEAR(static_cast<float>(out[i]), static_cast<float>(out_golden[i]), tol * k);
}
}
NVTECommGemmCtx* ctx_{};
int nranks_{};
int rank_{};
ncclComm_t comm_{};
};
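// Three patterns are exercised below: AllGather+GEMM gathers the column-partitioned
// operands before the GEMM, GEMM+ReduceScatter reduces partial k-sums and scatters the
// output columns, and GEMM+AllReduce sums full partial products on every rank.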
struct AgGemm : public CommGemmFixture {
PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override {
auto a_cols_num = nvte_comm_gemm_numroc(ctx_, m);
auto b_cols_num = nvte_comm_gemm_numroc(ctx_, n);
int64_t a_cols_start{};
int64_t b_cols_start{};
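// Exclusive prefix sum of the per-rank block sizes gives this rank's starting column
// offset; rank 0 keeps its zero initialization since MPI_Exscan leaves it undefined.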
CHECK_MPI(MPI_Exscan(&a_cols_num, &a_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD));
CHECK_MPI(MPI_Exscan(&b_cols_num, &b_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD));
return PatternDims{
.a_rows_start = 0,
.a_rows_num = k,
.a_cols_start = a_cols_start,
.a_cols_num = a_cols_num,
.b_rows_start = 0,
.b_rows_num = k,
.b_cols_start = b_cols_start,
.b_cols_num = b_cols_num,
.d_rows_start = a_cols_start,
.d_rows_num = a_cols_num,
.d_cols_start = 0,
.d_cols_num = n,
};
}
void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b,
const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out,
bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t stream) override {
nvte_all_gather_gemm(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad,
accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault);
}
};
struct GemmRs : public CommGemmFixture {
PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override {
auto rows_num = nvte_comm_gemm_numroc(ctx_, k);
auto d_cols_num = nvte_comm_gemm_numroc(ctx_, n);
int64_t rows_start{};
int64_t d_cols_start{};
CHECK_MPI(MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD));
CHECK_MPI(MPI_Exscan(&d_cols_num, &d_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD));
return PatternDims{
.a_rows_start = rows_start,
.a_rows_num = rows_num,
.a_cols_start = 0,
.a_cols_num = m,
.b_rows_start = rows_start,
.b_rows_num = rows_num,
.b_cols_start = 0,
.b_cols_num = n,
.d_rows_start = 0,
.d_rows_num = m,
.d_cols_start = d_cols_start,
.d_cols_num = d_cols_num,
};
}
void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b,
const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out,
bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t stream) override {
nvte_gemm_reduce_scatter(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad,
accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault);
}
};
struct GemmAr : public CommGemmFixture {
PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override {
auto rows_num = nvte_comm_gemm_numroc(ctx_, k);
int64_t rows_start{};
CHECK_MPI(MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD));
return PatternDims{
.a_rows_start = rows_start,
.a_rows_num = rows_num,
.a_cols_start = 0,
.a_cols_num = m,
.b_rows_start = rows_start,
.b_rows_num = rows_num,
.b_cols_start = 0,
.b_cols_num = n,
.d_rows_start = 0,
.d_rows_num = m,
.d_cols_start = 0,
.d_cols_num = n,
};
}
void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b,
const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out,
bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count,
cudaStream_t stream) override {
nvte_gemm_all_reduce(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad,
accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault);
}
void SetUp() override {
if (!IsMulticastSupported(rank_))
GTEST_SKIP() << "Multicast is not supported on device " << rank_;
}
};
TEST_P(AgGemm, Gemm) {
auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam();
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
a_type, AType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
b_type, BType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
d_type, DType, Run<AType, BType, DType, DType>(transa, transb, m, n, k, tol);)));
}
TEST_P(GemmRs, Gemm) {
auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam();
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
a_type, AType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
b_type, BType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
d_type, DType, Run<AType, BType, DType, DType>(transa, transb, m, n, k, tol);)));
}
TEST_P(GemmAr, Gemm) {
auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam();
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
a_type, AType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
b_type, BType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
d_type, DType, Run<AType, BType, DType, DType>(transa, transb, m, n, k, tol);)));
}
std::string ParamSuffix(const testing::TestParamInfo<Params>& info) {
const auto [a_type, b_type, d_type, transa, transb, m, n, k, _tol] = info.param;
std::ostringstream ss;
ss << static_cast<int>(a_type) << "_" << static_cast<int>(b_type) << "_"
<< static_cast<int>(d_type) << "_" << (transa ? "T" : "N") << (transb ? "T" : "N") << "_" << m
<< "x" << n << "x" << k;
return ss.str();
}
INSTANTIATE_TEST_SUITE_P(AgGemm, AgGemm,
testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16,
false, false, 256, 128, 64, 1e-3},
Params{DType::kFloat16, DType::kFloat16, DType::kFloat16,
false, true, 256, 128, 64, 1e-3},
Params{DType::kFloat16, DType::kFloat16, DType::kFloat16,
true, false, 256, 128, 64, 1e-3},
Params{DType::kBFloat16, DType::kBFloat16,
DType::kBFloat16, false, false, 256, 128, 64, 1e-3},
Params{DType::kBFloat16, DType::kBFloat16,
DType::kBFloat16, false, true, 256, 128, 64, 1e-3},
Params{DType::kBFloat16, DType::kBFloat16,
DType::kBFloat16, true, false, 256, 128, 64, 1e-3},
Params{DType::kFloat8E4M3, DType::kFloat8E4M3,
DType::kFloat16, true, false, 256, 128, 64, 1e-3},
Params{DType::kFloat8E4M3, DType::kFloat8E5M2,
DType::kFloat16, true, false, 256, 128, 64, 1e-3},
Params{DType::kFloat8E5M2, DType::kFloat8E4M3,
DType::kFloat16, true, false, 256, 128, 64, 1e-3}),
&ParamSuffix);
INSTANTIATE_TEST_SUITE_P(GemmRs, GemmRs,
testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16,
false, false, 64, 128, 256, 5e-2},
Params{DType::kFloat16, DType::kFloat16, DType::kFloat16,
false, true, 64, 128, 256, 5e-2},
Params{DType::kFloat16, DType::kFloat16, DType::kFloat16,
true, false, 64, 128, 256, 5e-2},
Params{DType::kBFloat16, DType::kBFloat16,
DType::kBFloat16, false, false, 64, 128, 256, 5e-2},
Params{DType::kBFloat16, DType::kBFloat16,
DType::kBFloat16, false, true, 64, 128, 256, 5e-2},
Params{DType::kBFloat16, DType::kBFloat16,
DType::kBFloat16, true, false, 64, 128, 256, 5e-2},
Params{DType::kFloat8E4M3, DType::kFloat8E4M3,
DType::kFloat16, true, false, 64, 128, 256, 5e-2},
Params{DType::kFloat8E4M3, DType::kFloat8E5M2,
DType::kFloat16, true, false, 64, 128, 256, 5e-2},
Params{DType::kFloat8E5M2, DType::kFloat8E4M3,
DType::kFloat16, true, false, 64, 128, 256, 5e-2}),
&ParamSuffix);
INSTANTIATE_TEST_SUITE_P(
GemmAr, GemmAr,
testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, true, false, 64,
64 * 4, 64 * 4, 5e-2},
Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, true, false, 64,
64 * 4, 64 * 4, 5e-2},
Params{DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kFloat16, true, false,
128, 128 * 4, 128 * 4, 5e-2},
Params{DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kFloat16, true, false,
128, 128 * 4, 128 * 4, 5e-2},
Params{DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kFloat16, true, false,
128, 128 * 4, 128 * 4, 5e-2}),
&ParamSuffix);
......@@ -22,7 +22,7 @@ def generate_configs():
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2")
pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2")
)
if is_devices_enough(4):
......@@ -30,8 +30,8 @@ def generate_configs():
pytest.param(
4,
(2, 2),
("dp", "tp"),
MeshResource(dp_resource="dp", tp_resource="tp"),
("dp", "tpsp"),
MeshResource(dp_resource="dp", tpsp_resource="tpsp"),
id=f"n4_dp2_tp2",
)
)
......@@ -43,8 +43,8 @@ def generate_context_parallel_configs_for_attn():
"""Generate CP combinations along with TP+DP for TestDistributedContextParallelSelfAttn only"""
configsL1 = []
configsL2 = []
mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp")
axes = ("dp", "cp", "tp")
mr = MeshResource(dp_resource="dp", cp_resource="cp", tpsp_resource="tpsp")
axes = ("dp", "cp", "tpsp")
DP_sizes = (1, 2)
CP_sizes = (1, 2, 4, 8)
TP_sizes = (1, 2)
......
#!/bin/bash
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
SCRIPT_NAME="${SCRIPT_NAME:-test.py}"
XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_command_buffer=''"
export XLA_FLAGS="${XLA_BASE_FLAGS}"
NUM_RUNS=$(nvidia-smi -L | wc -l)
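# One process per visible GPU: ranks 1..N-1 run in the background with output discarded,
# while rank 0 runs in the foreground so its output stays visible.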
for ((i=1; i<NUM_RUNS; i++))
do
CUDA_VISIBLE_DEVICES=$i python $SCRIPT_NAME 127.0.0.1:12345 $i $NUM_RUNS > /dev/null 2>&1 &
done
CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS
wait
......@@ -31,6 +31,7 @@ from transformer_engine.jax.cpp_extensions.quantization import (
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
NoScaleTensor,
ScaledTensor,
ScaledTensor1x,
ScaledTensor2x,
......@@ -182,7 +183,7 @@ ACTIVATION_TYPES = {
class TestActivation:
def ref_act(self, x, activation_type):
return _jax_act_lu(x, activation_type)
return _jax_act_lu(x, activation_type).data
def value_n_grad_ref_func(self, x, activation_type):
jitted_reference = jit(
......@@ -337,8 +338,8 @@ class TestNorm:
ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
else:
ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
# if isinstance(ln_out, ScaledTensor):
# ln_out = ln_out.dequantize()
# This is a no-op for non-quantized data
ln_out = ln_out.dequantize()
return ln_out
key = jax.random.PRNGKey(0)
......@@ -464,14 +465,23 @@ class TestNorm:
x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
)
ref_out, ref_mu, ref_rsigma = _jax_layernorm(
x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer
x,
gamma,
beta,
zero_centered_gamma,
epsilon,
quantizer=ref_quantizer,
)
else:
output, rsigma = tex.rmsnorm_fwd(
x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
)
ref_out, ref_rsigma = _jax_rmsnorm(
x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer
x,
gamma,
zero_centered_gamma,
epsilon,
quantizer=ref_quantizer,
)
ref_mu = None
......@@ -765,7 +775,9 @@ class TestFusedQuantize:
te_output, jax_output, precise_comparison=precise_comparison
)
else:
assert_allclose(te_output, jax_output)
assert isinstance(te_output, NoScaleTensor)
assert isinstance(jax_output, NoScaleTensor)
assert_allclose(te_output.data, jax_output.data)
if is_dbias:
# TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
......@@ -1020,7 +1032,6 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
else:
ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
if isinstance(ln_out, ScaledTensor):
ln_out = ln_out.dequantize()
return ln_out
......@@ -1177,7 +1188,7 @@ class TestFusedDense:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = _jax_act_lu(linear_1_out, activation_type)
x = _jax_act_lu(linear_1_out, activation_type).data
linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
if use_bias:
bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
......
......@@ -45,8 +45,8 @@ class TestDistributedSelfAttn:
_, seqlen, heads, _ = shape
is_dp_enabled = mesh_resource.dp_resource is not None
tp_size = 1
if mesh_resource.tp_resource is not None:
idx = mesh_axes.index(mesh_resource.tp_resource)
if mesh_resource.tpsp_resource is not None:
idx = mesh_axes.index(mesh_resource.tpsp_resource)
tp_size = mesh_shape[idx]
all_reduce_loss_bytes = 4 # 1 * FP32
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import unittest
import jax
import numpy as np
from utils import pytest_parametrize_wrapper, is_devices_enough
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax import fp8_autocast
def generate_mesh_configs():
configs = []
if is_devices_enough(2):
configs.append(
[2, (1, 2), ("dp", "tpsp"), MeshResource(dp_resource="dp", tpsp_resource="tpsp")]
)
if is_devices_enough(4):
configs.append(
[4, (2, 2), ("fsdp", "tp"), MeshResource(tp_resource="tp", fsdp_resource="fsdp")]
)
return configs
class TestMeshResource(unittest.TestCase):
def test_fp8_autocast_with_mesh_resource(self):
for mesh_config in generate_mesh_configs():
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource):
self.assertEqual(mesh_resource, global_mesh_resource())
......@@ -62,16 +62,16 @@ BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64
# Only test with FSDP and TP as DP is not used
def generate_fsdp_and_tp_configs():
# Only test with FSDP and TPSP as DP is not used
def generate_fsdp_and_tpsp_configs():
configs = []
if is_devices_enough(2):
configs.append(
[2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
[2, (1, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")]
)
if is_devices_enough(4):
configs.append(
[4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
[4, (2, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")]
)
return configs
......@@ -173,7 +173,9 @@ class TestDistributedLayernormMLP:
)
# Single GPU
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()
):
single_jitter = jax.jit(
value_and_grad_func,
static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
......@@ -184,14 +186,14 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding)
if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tpsp"))
b1_ = jax.device_put(b1, b1_sharding)
else:
b1_sharding = b1_ = None
......@@ -226,7 +228,12 @@ class TestDistributedLayernormMLP:
fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
if fwd_test_type == jnp.float16 and use_bias:
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5)
else:
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
for i in range(len(inputs)):
if multi_grads[i] is not None:
if isinstance(multi_grads[i], list):
......@@ -247,12 +254,12 @@ class TestDistributedLayernormMLP:
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad(
self,
......@@ -276,12 +283,12 @@ class TestDistributedLayernormMLP:
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad_shardy(
self,
......@@ -330,7 +337,7 @@ class TestDistributedLayernormMLP:
with use_jax_gemm(enabled=with_jax_gemm):
# Single GPUs
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type,
intermediate_dim=INTERMEDIATE,
......@@ -408,7 +415,7 @@ class TestDistributedLayernormMLP:
assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
......@@ -429,7 +436,7 @@ class TestDistributedLayernormMLP:
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
......@@ -452,7 +459,7 @@ class TestDistributedLayernormMLP:
)
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
......@@ -473,7 +480,7 @@ class TestDistributedLayernormMLP:
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
......
......@@ -41,11 +41,11 @@ class TestDistributedSoftmax:
if not bad_sharding:
x_pspec = PartitionSpec(
mesh_resource.dp_resource, mesh_resource.tp_resource, None, None
mesh_resource.dp_resource, mesh_resource.tpsp_resource, None, None
)
else:
x_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource
mesh_resource.dp_resource, None, None, mesh_resource.tpsp_resource
)
if broadcast_batch_mask:
......
......@@ -41,6 +41,7 @@ from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine_jax import (
NVTE_Fused_Attn_Backend,
get_cudnn_version,
get_device_compute_capability,
)
from distributed_test_base import assert_equal_collectives
......@@ -348,6 +349,14 @@ class FusedAttnRunner:
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
if (
get_device_compute_capability(0) == 100
and self.dropout_prob == 0.1
and self.attn_bias_type is not AttnBiasType.NO_BIAS
):
pytest.skip(
"For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
)
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
......@@ -397,7 +406,7 @@ class FusedAttnRunner:
self.mesh = Mesh(self.devices, self.mesh_axes)
self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1)
self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)
key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
......@@ -630,7 +639,7 @@ class FusedAttnRunner:
self.qkvo_psec = PartitionSpec(
self.mesh_resource.dp_resource,
self.mesh_resource.cp_resource,
self.mesh_resource.tp_resource,
self.mesh_resource.tpsp_resource,
None,
)
self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)
......@@ -658,7 +667,7 @@ class FusedAttnRunner:
if self.bias_shape == BiasShape._1HSS:
self.bias_pspec = PartitionSpec(
None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None
None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None
)
elif self.bias_shape == BiasShape._B1SS:
self.bias_pspec = PartitionSpec(
......
......@@ -14,10 +14,11 @@ from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling,
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.quantize import (
QuantizeConfig,
get_quantize_config,
is_fp8_available,
ScalingMode,
update_collections,
TensorSource,
)
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
......@@ -49,7 +50,7 @@ class TestHelper(unittest.TestCase):
class TestFP8Functions(unittest.TestCase):
def _check_default_state(self):
self.assertFalse(QuantizeConfig.is_fp8_enabled())
self.assertFalse(get_quantize_config().is_fp8_enabled())
def _compare_delay_scaling(self, ref, test):
self.assertTrue(ref.margin == test.margin)
......@@ -58,107 +59,90 @@ class TestFP8Functions(unittest.TestCase):
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
def _compare_current_scaling(self, test):
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
for tensor_source in TensorSource:
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source),
ScalingMode.CURRENT_TENSOR_SCALING,
)
def _compare_mxfp8_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
self.assertEqual(get_quantize_config().MARGIN, test.margin)
self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
for tensor_source in TensorSource:
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
)
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_delayed_scaling(self):
QuantizeConfig.finalize() # Ensure the test is not affected by previous tests.
self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()):
self._check_default_state()
self._check_default_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_default_state()
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_default_state()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_current_scaling(self):
QuantizeConfig.finalize() # Ensure the test is not affected by previous tests.
self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
with fp8_autocast(
enabled=False, fp8_recipe=Float8CurrentScaling(), mesh_resource=MeshResource()
):
self._check_default_state()
self._check_default_state()
cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_default_state()
cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_block_scaling(self):
QuantizeConfig.finalize() # Ensure the test is not affected by previous tests.
self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
with fp8_autocast(
enabled=False, fp8_recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()
):
self._check_default_state()
self._check_default_state()
bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=bs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=bs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_with_sharding_resource(self):
QuantizeConfig.finalize() # Ensure the test is not affected by previous tests.
self._check_default_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
mesh_s = (
(MeshResource(None, None)),
(MeshResource("dp", None)),
(MeshResource(None, "tp")),
(MeshResource("dp", "tp")),
)
# TODO (Ming Huang): Support multi-GPU testing. # pylint: disable=fixme
mesh_shape = (1, 1)
devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
with jax.sharding.Mesh(devices, ("dp", "tp")):
for sr in mesh_s:
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self.assertEqual(sr, global_mesh_resource())
self._check_default_state()
......@@ -23,11 +23,14 @@ from utils import EncoderLayer as RefEncoderLayer
from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.quantize import (
QuantizeConfig,
get_quantize_config,
ScalingMode,
is_fp8_available,
update_collections,
TensorSource,
fp8_autocast,
)
from transformer_engine.jax.sharding import MeshResource
@pytest.fixture(autouse=True, scope="function")
......@@ -262,6 +265,16 @@ ATTRS = [
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_WINDOW_SIZE: (2, 2),
},
# attrs29
{
_KEY_OF_RELATIVE_EMBEDDING: True,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "pre_scale_bias",
},
# attrs30
{
_KEY_OF_RELATIVE_EMBEDDING: True,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias",
},
]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......@@ -345,7 +358,7 @@ class BaseRunner:
ref_params, test_params = self._sync_params(ref_params, test_params)
if QuantizeConfig.is_fp8_enabled():
if get_quantize_config().is_fp8_enabled():
for _ in range(4):
_, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
inputs,
......@@ -354,12 +367,15 @@ class BaseRunner:
test_others,
test_layer,
)
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
if (
get_quantize_config().get_scaling_mode(TensorSource.X)
== ScalingMode.DELAYED_TENSOR_SCALING
):
_, updated_quantize_meta = flax.core.pop(
updated_state[0], QuantizeConfig.COLLECTION_NAME
updated_state[0], get_quantize_config().COLLECTION_NAME
)
test_others = update_collections(
{QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others
{get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others
)
del updated_quantize_meta
del updated_state
......@@ -489,29 +505,33 @@ class BaseTester:
def test_forward(self, data_shape, dtype, attrs):
"""Test normal datatype forward"""
QuantizeConfig.finalize() # Ensure FP8 disabled.
# Ensure FP8 disabled.
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
self.runner(attrs).test_forward(data_shape, dtype)
def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward"""
QuantizeConfig.finalize() # Ensure FP8 disabled.
# Ensure FP8 disabled.
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
self.runner(attrs).test_backward(data_shape, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test forward with fp8 enabled"""
QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
QuantizeConfig.finalize()
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test backward with fp8 enabled"""
QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
QuantizeConfig.finalize()
class TestEncoderLayer(BaseTester):
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import sys
from functools import partial
import jax
import jax.numpy as jnp
import jax.experimental.multihost_utils as jem
from jax.experimental import shard_map
from jax.sharding import NamedSharding, PartitionSpec
from transformer_engine.jax.dense import grouped_dense as te_grouped_dense
from transformer_engine.jax.quantize import (
QuantizerFactory,
ScalingMode,
)
from utils import assert_allclose, dtype_tols
N_GROUP = 8
MESH_AXIS_NAME = "fsdp"
def test_grouped_gemm_fp8_allgather(data_shape, kernel_fsdp_axis):
assert kernel_fsdp_axis in [1, 2]
x_shape, w_shape = data_shape
x_sharding = NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None, None, None))
w_sharding = (
NamedSharding(mesh, PartitionSpec(None, None, MESH_AXIS_NAME))
if kernel_fsdp_axis == 2
else NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None))
)
w_no_sharding = NamedSharding(mesh, PartitionSpec(None, None, None))
def init_data():
x_key = jax.random.PRNGKey(0)
w_key = jax.random.PRNGKey(1)
x = jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16)
w = jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16)
w_amax = jnp.max(jnp.abs(w), axis=range(1, w.ndim))
return x, w, w, w_amax
def test_func(outer_x, outer_w, outer_w_amax):
in_specs = (x_sharding.spec, w_sharding.spec, None)
out_specs = x_sharding.spec
@partial(
shard_map.shard_map,
mesh=mesh,
in_specs=in_specs,
out_specs=out_specs,
check_rep=False,
)
def sharded_group_gemm(x, w, w_amax):
group_size = x.shape[0]
x_reshaped = x.reshape(-1, x.shape[-1])
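# Equal-sized groups: each of the group_size groups is assigned the same number of rows
# of the flattened [tokens, hidden] input.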
n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size)
quantizer_set = QuantizerFactory.create_set(
scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2,
is_2x2x=True,
n_groups=group_size,
)
output = te_grouped_dense(
x_reshaped,
w,
n_groups,
kernel_amax=w_amax,
quantizer_set=quantizer_set,
kernel_fsdp_info=(MESH_AXIS_NAME, kernel_fsdp_axis),
)
output = output.reshape(*x.shape[:-1], -1)
return output
def run(x, w, w_amax):
output = sharded_group_gemm(x, w, w_amax)
return output
output, vjp_fn = jax.vjp(run, outer_x, outer_w, outer_w_amax)
dx, dw, _ = vjp_fn(output)
return output, dx, dw
def ref_func(outer_x, outer_w):
in_specs = (x_sharding.spec, w_no_sharding.spec)
out_specs = x_sharding.spec
@partial(
shard_map.shard_map,
mesh=mesh,
in_specs=in_specs,
out_specs=out_specs,
check_rep=False,
)
def sharded_group_gemm(x, w):
group_size = x.shape[0]
x_reshaped = x.reshape(-1, x.shape[-1])
n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size)
quantizer_set = QuantizerFactory.create_set(
scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2,
is_2x2x=True,
n_groups=group_size,
)
output = te_grouped_dense(x_reshaped, w, n_groups, quantizer_set=quantizer_set)
output = output.reshape(*x.shape[:-1], -1)
return output
def run(x, w):
output = sharded_group_gemm(x, w)
return output
output, vjp_fn = jax.vjp(run, outer_x, outer_w)
dx, dw = vjp_fn(output)
return output, dx, dw
init_func = jax.jit(init_data, out_shardings=(x_sharding, w_sharding, w_no_sharding, None))
x, w, w_global, w_amax = init_func()
o_sharding = x_sharding
test_func_jitted = jax.jit(
test_func,
in_shardings=(x_sharding, w_sharding, None),
out_shardings=(o_sharding, x_sharding, w_sharding),
)
ref_func_jitted = jax.jit(
ref_func,
in_shardings=(x_sharding, w_no_sharding),
out_shardings=(o_sharding, x_sharding, w_no_sharding),
)
out, dx, dw = test_func_jitted(x, w, w_amax)
ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global)
e4m3_tols = dtype_tols(jnp.float8_e4m3fn)
e5m2_tols = dtype_tols(jnp.float8_e5m2)
out, ref_out = jem.process_allgather((out, ref_out))
dx, ref_dx = jem.process_allgather((dx, ref_dx))
dw, ref_dw = jem.process_allgather((dw, ref_dw))
assert_allclose(out, ref_out, **e4m3_tols)
assert_allclose(dx, ref_dx, **e5m2_tols)
assert_allclose(dw, ref_dw, **e5m2_tols)
if __name__ == "__main__":
coord_addr = sys.argv[1]
proc_id = int(sys.argv[2])
num_procs = int(sys.argv[3])
jax.distributed.initialize(
coordinator_address=coord_addr, num_processes=num_procs, process_id=proc_id
)
mesh = jax.make_mesh((num_procs,), (MESH_AXIS_NAME,))
with mesh:
data_shapes = [((4, 16, 128, 7168), (7168, 2048))]
for data_shape in data_shapes:
for kernel_fsdp_axis in [1, 2]:
test_grouped_gemm_fp8_allgather(data_shape, kernel_fsdp_axis)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
LOGICAL_RULES = [
[(("a1", None), ("a2", "ma2")), False],
[(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True],
[(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False],
[(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True],
[(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True],
]
MeshS = [
MeshResource(),
MeshResource("data", None),
MeshResource(None, "model"),
MeshResource("data", "model"),
]
class TestShardingSideAPI:
@pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES)
@pytest.mark.parametrize("sr", MeshS)
def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
with global_shard_guard(sr):
try:
target_te_rules = extend_logical_axis_rules(tuple())
extended_rules = extend_logical_axis_rules(base_rules)
assert extended_rules == (*base_rules, *target_te_rules)
assert not need_assert
except AssertionError as ae:
assert need_assert, f"{ae.args}"
......@@ -274,6 +274,8 @@ model_configs_mla = {
"mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
"mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference
}
......
......@@ -37,6 +37,12 @@ model_configs_flash_attn = {
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
), # GQA
"cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA
"cp_3_0": ModelConfig(2, 4096, 12, 192, attn_mask_type="causal", head_dim_v=128), # MLA
"cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA
"cp_3_2": ModelConfig(
2, 4096, 12, 192, attn_mask_type="causal", window_size=(512, 0), head_dim_v=128
), # MLA
"cp_3_3": ModelConfig(2, 4096, 12, 192, window_size=(512, 512), head_dim_v=128), # MLA
}
......@@ -82,6 +88,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!")
subprocess.run(
get_bash_arguments(
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Unit tests for context parallel utils."""
import torch
import unittest
from typing import Tuple
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
get_batch_on_this_cp_rank,
pad_thd_sequences_for_cp,
generate_positional_ids_for_cp,
)
class TestSequencePadding(unittest.TestCase):
def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(self):
"""Test with custom padding values for all tensors."""
# Setup
input_ids = torch.tensor([1, 1, 1, 2, 2, 3, 3, 3, 3])
cu_seqlens = torch.tensor([0, 3, 5, 9])
labels = torch.tensor([-100, -100, -100, -100, -100, -100, -100, 13, -100])
positional_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])
divisibility_factor = 8
pid = 777
label_pad = -200
input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
input_ids.unsqueeze(0),
labels.unsqueeze(0),
cu_seqlens,
divisibility_factor,
padding_token_id=pid,
padding_label_id=label_pad,
)
positional_ids_padded = generate_positional_ids_for_cp(
cu_seqlens,
divisibility_factor,
)
# Padded layout: [a a a p p p p p | b b p p p p p p | c c c c p p p p]
print("input_ids_padded: ", input_ids_padded)
print("labels_padded: ", labels_padded)
print("positional_ids_padded: ", positional_ids_padded)
print("cu_seqlens_padded: ", cu_seqlens_padded)
expected_input_ids = torch.tensor(
[1, 1, 1, pid, pid, pid, pid, pid, 2, 2, pid, pid, pid, pid, pid, pid, 3, 3, 3, 3, pid, pid, pid, pid]
)
expected_cu_seqlens_padded = torch.tensor([0, 8, 16, 24])
expected_labels_padded = torch.tensor(
[-100, -100, -100, label_pad, label_pad, label_pad, label_pad, label_pad, -100, -100, label_pad, label_pad, label_pad, label_pad, label_pad, label_pad, -100, -100, 13, -100, label_pad, label_pad, label_pad, label_pad]
)
expected_positional_ids = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
)
assert torch.equal(input_ids_padded, expected_input_ids)
assert torch.equal(labels_padded, expected_labels_padded)
assert torch.equal(positional_ids_padded, expected_positional_ids)
assert torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)
def test_mixed_sequence_lengths_with_divisibility_factor(self):
"""Test with sequences both shorter and longer than divisibility factor."""
# Setup - divisibility factor 6
# Seq 1: length 2 (shorter than 6, needs 4 padding)
# Seq 2: length 7 (longer than 6, needs 5 padding to reach 12)
# Seq 3: length 4 (shorter than 6, needs 2 padding)
# Seq 4: length 10 (longer than 6, needs 2 padding to reach 12)
input_ids = torch.tensor(
[1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
)
labels = torch.tensor(
[10, 11, 20, 21, 22, 23, 24, 25, 26, 30, 31, 32, 33, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
)
positional_ids = torch.tensor(
[0, 1, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
)
cu_seqlens = torch.tensor([0, 2, 9, 13, 23])
divisibility_factor = 6
pid = 999
label_pad = -300
# Execute
input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
input_ids.unsqueeze(0),
labels.unsqueeze(0),
cu_seqlens,
divisibility_factor,
padding_token_id=pid,
padding_label_id=label_pad,
)
positional_ids_padded = generate_positional_ids_for_cp(
cu_seqlens,
divisibility_factor,
)
# Assert
# Seq 1: [1,1] + 4 pads = 6 total
# Seq 2: [2,2,2,2,2,2,2] + 5 pads = 12 total
# Seq 3: [3,3,3,3] + 2 pads = 6 total
# Seq 4: [4,4,4,4,4,4,4,4,4,4] + 2 pads = 12 total
expected_input_ids = torch.tensor(
[1, 1, pid, pid, pid, pid,  # Seq 1: 2 + 4 padding
2, 2, 2, 2, 2, 2, 2, pid, pid, pid, pid, pid,  # Seq 2: 7 + 5 padding
3, 3, 3, 3, pid, pid,  # Seq 3: 4 + 2 padding
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, pid, pid]  # Seq 4: 10 + 2 padding
)
expected_labels = torch.tensor(
[10, 11, label_pad, label_pad, label_pad, label_pad,
20, 21, 22, 23, 24, 25, 26, label_pad, label_pad, label_pad, label_pad, label_pad,
30, 31, 32, 33, label_pad, label_pad,
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, label_pad, label_pad]
)
expected_positional_ids = torch.tensor(
[0, 1, 2, 3, 4, 5,  # Seq 1 positions continue through padding
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,  # Seq 2 positions continue
0, 1, 2, 3, 4, 5,  # Seq 3 positions continue
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]  # Seq 4 positions continue
)
expected_cu_seqlens_padded = torch.tensor([0, 6, 18, 24, 36])
self.assertTrue(torch.equal(input_ids_padded, expected_input_ids))
self.assertTrue(torch.equal(labels_padded, expected_labels))
self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids))
self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded))
def test_sequences_longer_than_divisibility_factor(self):
"""Test with all sequences longer than the divisibility factor."""
# Setup - divisibility factor 4, all sequences longer than 4
# Seq 1: length 7 (needs 1 padding to reach 8)
# Seq 2: length 11 (needs 1 padding to reach 12)
# Seq 3: length 5 (needs 3 padding to reach 8)
input_ids = torch.tensor(
[1, 1, 1, 1, 1, 1, 1,  # 7 tokens
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,  # 11 tokens
3, 3, 3, 3, 3]  # 5 tokens
)
labels = torch.tensor(
[100, 101, 102, 103, 104, 105, 106, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 300, 301, 302, 303, 304]
)
positional_ids = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 2, 3, 4]
)
cu_seqlens = torch.tensor([0, 7, 18, 23])
divisibility_factor = 4
pid = 888
label_pad = -400
# Execute
input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
input_ids.unsqueeze(0),
labels.unsqueeze(0),
cu_seqlens,
divisibility_factor,
padding_token_id=pid,
padding_label_id=label_pad,
)
positional_ids_padded = generate_positional_ids_for_cp(
cu_seqlens,
divisibility_factor,
)
# Assert
# Seq 1: 7 + 1 pad = 8 (divisible by 4)
# Seq 2: 11 + 1 pad = 12 (divisible by 4)
# Seq 3: 5 + 3 pads = 8 (divisible by 4)
expected_input_ids = torch.tensor(
[1, 1, 1, 1, 1, 1, 1, pid, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, pid, 3, 3, 3, 3, 3, pid, pid, pid]
)
expected_labels = torch.tensor(
[100, 101, 102, 103, 104, 105, 106, label_pad, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, label_pad, 300, 301, 302, 303, 304, label_pad, label_pad, label_pad]
)
expected_positional_ids = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7]
)
expected_cu_seqlens_padded = torch.tensor([0, 8, 20, 28])
self.assertTrue(torch.equal(input_ids_padded, expected_input_ids))
self.assertTrue(torch.equal(labels_padded, expected_labels))
self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids))
self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded))
class TestContextParallelUtils(unittest.TestCase):
"""Test utilities for context parallel functionality."""
def setUp(self):
"""Set up mock distributed environment."""
# Mock torch.distributed functions
self.original_get_world_size = torch.distributed.get_world_size
self.original_get_rank = torch.distributed.get_rank
def tearDown(self):
"""Restore original torch.distributed functions."""
torch.distributed.get_world_size = self.original_get_world_size
torch.distributed.get_rank = self.original_get_rank
def _mock_distributed_env(self, cp_size, cp_rank):
"""Mock the distributed environment for testing."""
def mock_get_world_size(group=None):
return cp_size
def mock_get_rank(group=None):
return cp_rank
torch.distributed.get_world_size = mock_get_world_size
torch.distributed.get_rank = mock_get_rank
def test_cp_rank_slicing_simple_case(self):
"""Test CP rank slicing with a simple 2-rank, single sequence case."""
# Setup: Single sequence of length 8, CP size = 2
# Each sequence gets divided into 2*cp_size = 4 slices of size 2 each
# Rank 0 gets slices [0,1] and [6,7] (first and last)
# Rank 1 gets slices [2,3] and [4,5] (second and second-to-last)
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) # Shape: (1, 8) - batch first
labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]])
position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) # Shape: (8,) - 1D as expected
cu_seqlens = torch.tensor([0, 8])
# Test rank 0
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# Rank 0 should get indices [0,1] and [6,7]
expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8]])
expected_labels_r0 = torch.tensor([[10, 20, 70, 80]])
expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7])
self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
self.assertTrue(torch.equal(labels_r0, expected_labels_r0))
self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0))
# Test rank 1
self._mock_distributed_env(cp_size=2, cp_rank=1)
input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# Rank 1 should get indices [2,3] and [4,5]
expected_input_ids_r1 = torch.tensor([[3, 4, 5, 6]])
expected_labels_r1 = torch.tensor([[30, 40, 50, 60]])
expected_pos_ids_r1 = torch.tensor([2, 3, 4, 5])
self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1))
self.assertTrue(torch.equal(labels_r1, expected_labels_r1))
self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1))
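# NOTE (illustration): a sketch of the zigzag selection both assertions encode
# (an illustration of the expected behaviour, not the library implementation).
# Each sequence splits into 2 * cp_size equal chunks; rank r keeps chunk r and
# its mirror chunk 2 * cp_size - 1 - r:
#
#   def zigzag_indices(seq_len, cp_size, cp_rank):
#       chunk = seq_len // (2 * cp_size)
#       mirror = 2 * cp_size - 1 - cp_rank
#       return torch.cat([
#           torch.arange(cp_rank * chunk, (cp_rank + 1) * chunk),
#           torch.arange(mirror * chunk, (mirror + 1) * chunk),
#       ])
#
#   zigzag_indices(8, 2, 0)  # tensor([0, 1, 6, 7])
#   zigzag_indices(8, 2, 1)  # tensor([2, 3, 4, 5])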
def test_cp_rank_slicing_multiple_sequences(self):
"""Test CP rank slicing with multiple sequences."""
# Setup: Two sequences of length 8 each, CP size = 2
# Total sequence length = 16, cu_seqlens = [0, 8, 16]
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18]])
labels = torch.tensor(
[[10, 20, 30, 40, 50, 60, 70, 80, 110, 120, 130, 140, 150, 160, 170, 180]]
)
position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7])
cu_seqlens = torch.tensor([0, 8, 16])
# Test rank 0
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# For each sequence, rank 0 gets first and last slices
# Seq 1: indices [0,1] and [6,7] -> values [1,2] and [7,8]
# Seq 2: indices [8,9] and [14,15] -> values [11,12] and [17,18]
expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8, 11, 12, 17, 18]])
expected_labels_r0 = torch.tensor([[10, 20, 70, 80, 110, 120, 170, 180]])
expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7, 0, 1, 6, 7])
self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
self.assertTrue(torch.equal(labels_r0, expected_labels_r0))
self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0))
def test_cp_rank_slicing_with_cp_size_1(self):
"""Test that CP size = 1 returns original tensors unchanged."""
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]])
position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
cu_seqlens = torch.tensor([0, 8])
self._mock_distributed_env(cp_size=1, cp_rank=0)
input_ids_result, labels_result, pos_ids_result = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# With CP size = 1, should return original tensors
self.assertTrue(torch.equal(input_ids_result, input_ids))
self.assertTrue(torch.equal(labels_result, labels))
self.assertTrue(torch.equal(pos_ids_result, position_ids))
def test_cp_rank_slicing_sequence_dim_detection(self):
"""Test that the function correctly detects sequence dimension."""
# Test with sequence dimension = 0 (sequence_length, batch_size)
input_ids = torch.tensor(
[[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]]
) # (8, 2)
labels = torch.tensor(
[[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]]
)
position_ids = torch.tensor(
[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]
)
cu_seqlens = torch.tensor([0, 8])
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# Should get indices [0,1] and [6,7] along dimension 0
expected_input_ids_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]])
expected_labels_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]])
expected_pos_ids_r0 = torch.tensor([[0, 0], [1, 1], [6, 6], [7, 7]])
self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
self.assertTrue(torch.equal(labels_r0, expected_labels_r0))
self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0))
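# NOTE (assumption): this test relies on get_batch_on_this_cp_rank inferring
# the sequence dimension from cu_seqlens, e.g. as the dimension whose size
# matches cu_seqlens[-1] (8 here), so a (seq, batch) layout is sliced along
# dim 0. The exact heuristic is an implementation detail and is only assumed here.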
def test_cp_rank_slicing_mixed_dimensions(self):
"""Test CP rank slicing where input_ids/labels are 1D but position_ids has batch dimension."""
# Setup: Single sequence of length 8, CP size = 2
# This tests the opposite case from the simple test:
# - input_ids and labels: 1D (no batch dimension)
# - position_ids: 2D (has batch dimension)
input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # Shape: (8,) - 1D
labels = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80]) # Shape: (8,) - 1D
position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) # Shape: (1, 8) - 2D with batch
cu_seqlens = torch.tensor([0, 8])
# Test rank 0
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# Rank 0 should get indices [0,1] and [6,7]
expected_input_ids_r0 = torch.tensor([1, 2, 7, 8]) # 1D result
expected_labels_r0 = torch.tensor([10, 20, 70, 80]) # 1D result
expected_pos_ids_r0 = torch.tensor([[0, 1, 6, 7]]) # 2D result (preserves batch dim)
self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
self.assertTrue(torch.equal(labels_r0, expected_labels_r0))
self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0))
# Test rank 1
self._mock_distributed_env(cp_size=2, cp_rank=1)
input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank(
cu_seqlens, input_ids, labels, position_ids
)
# Rank 1 should get indices [2,3] and [4,5]
expected_input_ids_r1 = torch.tensor([3, 4, 5, 6]) # 1D result
expected_labels_r1 = torch.tensor([30, 40, 50, 60]) # 1D result
expected_pos_ids_r1 = torch.tensor([[2, 3, 4, 5]]) # 2D result (preserves batch dim)
self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1))
self.assertTrue(torch.equal(labels_r1, expected_labels_r1))
self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1))
def test_integration_with_padding_and_cp_slicing(self):
"""Integration test: pad sequences then slice for CP ranks."""
# Start with unpadded sequences
input_ids = torch.tensor([1, 1, 2, 2, 2]) # Two sequences: [1,1] and [2,2,2]
labels = torch.tensor([10, 11, 20, 21, 22])
positional_ids = torch.tensor([0, 1, 0, 1, 2])
cu_seqlens = torch.tensor([0, 2, 5])
divisibility_factor = 4 # Will pad to lengths 4 and 4
# First, pad sequences
input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
input_ids.unsqueeze(0),
labels.unsqueeze(0),
cu_seqlens,
divisibility_factor,
padding_token_id=0,
padding_label_id=-100,
)
positional_ids_padded = generate_positional_ids_for_cp(
cu_seqlens,
divisibility_factor,
)
# Expected after padding: [1,1,0,0,2,2,2,0] with cu_seqlens [0,4,8]
expected_padded = torch.tensor([1, 1, 0, 0, 2, 2, 2, 0])
self.assertTrue(torch.equal(input_ids_padded, expected_padded))
# Now test CP slicing with cp_size=2
# Test rank 0
self._mock_distributed_env(cp_size=2, cp_rank=0)
input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank(
cu_seqlens_padded,
input_ids_padded.unsqueeze(0),
labels_padded.unsqueeze(0),
positional_ids_padded,
)
# Each sequence of length 4 gets divided into 4 slices of size 1
# Rank 0 gets slices [0] and [3] from each sequence
# Seq 1: indices [0] and [3] -> values [1] and [0]
# Seq 2: indices [4] and [7] -> values [2] and [0]
expected_input_ids_r0 = torch.tensor([[1, 0, 2, 0]])
self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))
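# The end-to-end flow exercised above: pad_thd_sequences_for_cp first makes
# each sequence length divisible by divisibility_factor (4 == 2 * cp_size
# here), so get_batch_on_this_cp_rank can then split every padded sequence
# evenly into zigzag chunks of size 1.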
if __name__ == "__main__":
unittest.main()
......@@ -268,7 +268,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
)[0]
expected_underflows = (
((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5)
((tensor_fp8.dequantize() == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5)
)
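# underflows% counts the elements that became exactly zero only through
# quantization: (zeros after dequantize minus zeros already in the input),
# as a percentage of the 100 * 100 * 5 elements in the tensor.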
assert debug_api.transformer_engine.inspect_tensor_enabled(
......@@ -302,7 +302,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
)[0]
# Second config in same yaml
tensor = torch.rand((100, 100, 5))
tensor = torch.rand((100, 100, 5)).cuda()
debug_api.transformer_engine.inspect_tensor(
"decoder.6.mlp.fc1",
tensor_name="activation",
......@@ -316,7 +316,9 @@ def test_statistics_collection(configs_dir, feature_dirs):
stats = log()
stats_names = [x[3] for x in stats.keys()]
assert all(s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"])
assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean()
torch.testing.assert_close(
stats[("decoder.6.mlp.fc1", "activation", "mean", 200)], tensor.mean()
)
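# torch.testing.assert_close tolerates the tiny floating-point differences
# between the logged statistic and tensor.mean() recomputed here; exact
# equality between two separately computed GPU reductions is brittle.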
debug_api.transformer_engine.inspect_tensor(
"decoder.7.mlp.fc1",
......@@ -331,7 +333,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
stats = log()
stats_names = [x[3] for x in stats.keys()]
assert all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"])
assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max()
torch.testing.assert_close(stats[("decoder.7.mlp.fc1", "weight", "max", 200)], tensor.max())
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.7.mlp.fc1", tensor_name="weight", iteration=201
......@@ -377,7 +379,7 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
return quantizer(t.cuda())
shape = [1024, 1024]
tensors = [torch.randn(shape) for _ in range(2)]
tensors = [torch.randn(shape).cuda() for _ in range(2)]
tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]
feed(tensors[0], tensors_fp8[0], quantizer)
......
......@@ -119,6 +119,9 @@ def read_log(log_dir: str) -> str:
def test_sanity(feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats))
with debug_session(log_all_stats_config, feature_dirs) as log_dir:
model = te.Linear(128, 128, params_dtype=torch.bfloat16)
......@@ -164,8 +167,8 @@ def test_numerics(fp8_recipe, feature_dirs):
num_quantizers=3,
)
tensor = torch.zeros(1024, 1024).cuda()
tensor[0, :] = 1000
tensor = torch.randn(1024, 1024).cuda()
tensor[0, 100:200] = -0.0
quantizer = recipe_state.make_quantizers()[0]
quantized_tensor = quantizer(tensor)
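# NOTE (assumption): the block of -0.0 values exercises the (tensor == 0)
# correction in the expected-underflow formula below, since -0.0 compares
# equal to 0 and must not be counted as an underflow caused by quantization.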
......@@ -188,15 +191,13 @@ def test_numerics(fp8_recipe, feature_dirs):
if "underflows%" in line:
underflows = float(line.split("value=")[1])
expected = (
((dequantized_tensor == 0).sum() - (tensor == 0).sum())
/ dequantized_tensor.numel()
* 100
((dequantized_tensor == 0).sum() - (tensor == 0).sum()) / tensor.numel() * 100
)
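# Same underflow definition as in test_statistics_collection: only elements
# zeroed by quantization count, and tensor.numel() equals
# dequantized_tensor.numel(), so either denominator is equivalent.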
assert underflows == pytest.approx(expected.cpu(), abs=1e-4)
if "mse" in line:
mse = float(line.split("value=")[1])
expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean")
assert mse == pytest.approx(expected.cpu(), abs=1e-6)
assert mse == pytest.approx(expected.cpu(), abs=1e-4)
if "overflows%" in line:
overflows = float(line.split("value=")[1])
expected = (
......@@ -207,6 +208,9 @@ def test_numerics(fp8_recipe, feature_dirs):
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
# If a layer does not invoke any feature in the current iteration,
# it switches into non-debug mode.
# This test checks whether this works correctly -
......