Unverified Commit d7c9777e authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Remove `nvidia-mathdx` dependency (#2295)



* Remove nvidia-mathdx dep
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix SR
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add comment
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d2945c6a
......@@ -19,7 +19,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1
pip install cmake==3.21.0 pybind11[global] ninja
- name: 'Checkout'
uses: actions/checkout@v3
with:
......@@ -43,7 +43,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
- name: 'Checkout'
uses: actions/checkout@v3
with:
......@@ -63,7 +63,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install pybind11[global] nvidia-mathdx==25.1.1
run: pip install pybind11[global]
- name: 'Checkout'
uses: actions/checkout@v3
with:
......@@ -83,7 +83,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
run: pip install torch pybind11[global] einops onnxscript
- name: 'Checkout'
uses: actions/checkout@v3
with:
......
......@@ -23,7 +23,7 @@ git checkout $TARGET_BRANCH
git submodule update --init --recursive
# Install deps
/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel nvidia-mathdx==25.1.1
/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel
if $BUILD_METAPACKAGE ; then
cd /TransformerEngine
......
......@@ -3,8 +3,7 @@
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
......@@ -98,28 +98,6 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
# NVIDIA MathDX include directory (from Python package install location)
if(NOT DEFINED MATHDX_INCLUDE_DIR)
execute_process(
COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
OUTPUT_VARIABLE _PIP_SHOW_MATHDX
ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR
RESULT_VARIABLE _PIP_SHOW_MATHDX_RES
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0)
message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}")
endif()
string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}")
if(NOT _MATHDX_LOC_MATCH)
message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}")
endif()
set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
endif()
if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
endif()
# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
......@@ -263,7 +241,6 @@ target_link_libraries(transformer_engine PUBLIC
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR})
target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
......
......@@ -19,9 +19,9 @@
#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/curanddx.hpp"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/builders/sm100_common.inl"
......@@ -38,15 +38,6 @@ namespace transformer_engine {
namespace detail {
namespace {
// Define a cuRANDDx descriptor
// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g.,
// if shared memory, if needed, is enough for the described problem, usually not applicable.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() + curanddx::SM<800>() + curanddx::Thread());
using namespace cute;
using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor
......@@ -502,8 +493,9 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
// Initialize RNG for tile
const size_t rng_sequence
= thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = uint4{0, 0, 0, 0};
CUTLASS_PRAGMA_UNROLL
......@@ -511,7 +503,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(acc_scales[v], cutlass::platform::numeric_limits<ElementAccumulator>::max());
// auto acc_scale = acc_scales[v];
if constexpr (kEnableStochasticRounding) {
random_uint4 = dist.generate4(rng);
random_uint4 = rng.generate4();
output_frgs[v] = StochasticNumericConverter(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs[v],
......
......@@ -17,9 +17,9 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/curanddx.hpp"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
namespace transformer_engine {
......@@ -33,14 +33,6 @@ using std::uint8_t;
using transformer_engine::detail::TypeExtrema;
// Define a cuRANDDx descriptor
// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g.,
// if shared memory, if needed, is enough for the described problem, usually not applicable.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
curanddx::SM<800>() + curanddx::Thread());
// clang-format off
/*
......@@ -209,12 +201,15 @@ __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_
return global_encode_scale;
}
__device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) {
__device__ __forceinline__ uint32_t
get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10>&
rng, // philox4x32_native_state<10>: 10 rounds of philox4_32
uint4& random_uint4, int& rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
curanddx::uniform_bits dist;
random_uint4 = dist.generate4(rng);
random_uint4 = rng.generate4();
}
// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t* const rbits_arr = reinterpret_cast<uint32_t*>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
......@@ -348,9 +343,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0};
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = kApplyStochasticRounding ? rng.generate4() : uint4{0, 0, 0, 0};
int rnd_idx =
0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_
namespace transformer_engine {
namespace curanddx {
namespace detail {
inline constexpr unsigned int philox4x32_w32_0 = 0x9E3779B9U;
inline constexpr unsigned int philox4x32_w32_1 = 0xBB67AE85U;
inline constexpr unsigned int philox4x32_m4x32_0 = 0xD2511F53U;
inline constexpr unsigned int philox4x32_m4x32_1 = 0xCD9E8D57U;
__forceinline__ __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int* hip) {
*hip = __umulhi(a, b);
return a * b;
}
__forceinline__ __device__ uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0;
unsigned int hi1;
unsigned int lo0 = mulhilo32(philox4x32_m4x32_0, ctr.x, &hi0);
unsigned int lo1 = mulhilo32(philox4x32_m4x32_1, ctr.z, &hi1);
uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
return ret;
}
template <unsigned int Rounds>
__forceinline__ __device__ uint4 multiple_rounds(uint4 c, uint2 k) {
for (unsigned int i = 0; i < Rounds - 1; i++) {
c = single_round(c, k); // 1
k.x += philox4x32_w32_0;
k.y += philox4x32_w32_1;
}
return single_round(c, k); // Rounds
}
template <unsigned int Rounds>
struct philox4x32_native_state {
static constexpr unsigned int rounds = Rounds;
uint4 ctr;
uint2 key;
__forceinline__ __device__ void philox_state_incr() {
if (++ctr.x) return;
if (++ctr.y) return;
if (++ctr.z) return;
++ctr.w;
}
__forceinline__ __device__ void philox_state_incr(size_t n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
ctr.x += nlo;
if (ctr.x < nlo) nhi++;
ctr.y += nhi;
if (nhi <= ctr.y) return;
if (++ctr.z) return;
++ctr.w;
}
__forceinline__ __device__ void philox_state_incr_hi(size_t n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
ctr.z += nlo;
if (ctr.z < nlo) nhi++;
ctr.w += nhi;
}
// offset is the total # of 128bits generated with a single generate4() call
__forceinline__ __device__ void skip_offset(size_t n) { philox_state_incr(n); }
__forceinline__ __device__ void skip_subsequence(size_t n) { philox_state_incr_hi(n); }
__forceinline__ __device__ void init(size_t seed, size_t subsequence, size_t offset) {
ctr = make_uint4(0, 0, 0, 0);
key.x = (unsigned int)seed;
key.y = (unsigned int)(seed >> 32);
skip_subsequence(subsequence);
skip_offset(offset);
}
__forceinline__ __device__ uint4 generate4() {
auto tmp = multiple_rounds<Rounds>(ctr, key);
philox_state_incr();
return tmp;
}
};
} // namespace detail
} // namespace curanddx
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_
......@@ -32,9 +32,6 @@ namespace transformer_engine {
#if FP4_TYPE_SUPPORTED
namespace nvfp4_transpose {
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
curanddx::SM<800>() + curanddx::Thread());
using namespace ptx;
using nvfp4_scale_t = fp8e4m3;
......@@ -139,12 +136,15 @@ __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const
return global_encode_scale;
}
__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
__device__ __forceinline__ uint32_t
get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10>
&rng, // philox4x32_native_state<10>: 10 rounds of philox4_32
uint4 &random_uint4, int &rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
curanddx::uniform_bits dist;
random_uint4 = dist.generate4(rng);
random_uint4 = rng.generate4();
}
// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
......@@ -363,9 +363,11 @@ __global__ void __launch_bounds__(THREADS_NUM)
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0};
int rnd_idx =
0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x
......@@ -874,9 +876,11 @@ __global__ void __launch_bounds__(THREADS_NUM)
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0};
int rnd_idx =
0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment