Commit 4099aa8e authored by yuguo's avatar yuguo
Browse files
parents c520cba3 96f9c6de
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: bug
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**Steps/Code to reproduce bug**
Please list *minimal* steps or code snippet for us to be able to reproduce the bug.
A helpful guide on how to craft a minimal bug report http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports.
**Expected behavior**
A clear and concise description of what you expected to happen.
**Environment overview (please complete the following information)**
- Environment location: [Bare-metal, Docker, Cloud(specify cloud provider - AWS, Azure, GCP, Colab)]
- Method of Transformer Engine install: [pip install or from source]. Please specify exact commands you used to install.
- If method of install is [Docker], provide `docker pull` & `docker run` commands used
**Environment details**
If NVIDIA docker image is used you don't need to specify these.
Otherwise, please provide:
- OS version
- PyTorch version
- Python version
- Transformer Engine version
- CUDA version
- CUDNN version
**Device details**
- GPU model
**Additional context**
Add any other context about the problem here.
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: feature request
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
Provide a code snippet on how new APIs/changes would be used by others.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
......@@ -9,7 +9,12 @@ set -e
TE_LIB_PATH=`pip3 show transformer-engine | grep Location | cut -d ' ' -f 2`
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH
# Set parallelization parameters
NUM_PHYSICAL_CORES=$(nproc)
NUM_PARALLEL_JOBS=4
cd $TE_PATH/tests/cpp
cmake -GNinja -Bbuild .
cmake --build build
ctest --test-dir build -j4
export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS))
ctest --test-dir build -j$NUM_PARALLEL_JOBS
......@@ -2,14 +2,33 @@
#
# See LICENSE for license information.
set -xe
function error_exit() {
echo "Error: $1"
exit 1
}
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
RET=0
FAILED_CASES=""
: ${TE_PATH:=/opt/transformerengine}
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install requirements"
# Make the encoder tests run-to-run deterministic so that CI results are stable
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
......@@ -2,22 +2,41 @@
#
# See LICENSE for license information.
set -xe
function error_exit() {
echo "Error: $1"
exit 1
}
pip3 install "nltk>=3.8.2"
pip3 install pytest==8.2.1
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
RET=0
FAILED_CASES=""
pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
: ${TE_PATH:=/opt/transformerengine}
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py
# Run the non-distributed JAX unit tests. test_praxis_layers.py is deliberately
# excluded here, so the failure label must name the suite that actually ran —
# labeling it "test_praxis_layers.py" (the ignored file) misreported failures.
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py || test_fail "tests/jax"
# Test without custom calls
NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "test_mnist.py"
# Make the encoder tests run-to-run deterministic so that CI results are stable
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
......@@ -2,34 +2,53 @@
#
# See LICENSE for license information.
set -e
function error_exit() {
echo "Error: $1"
exit 1
}
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
RET=0
FAILED_CASES=""
: "${TE_PATH:=/opt/transformerengine}"
pip3 install wheel
pip3 install wheel || error_exit "Failed to install wheel"
cd $TE_PATH
pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax
pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax || error_exit "Failed to uninstall transformer-engine transformer-engine-cu12 transformer-engine-jax"
VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel
wheel unpack dist/*
NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel"
wheel unpack dist/* || error_exit "Failed to unpack dist/*"
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}"
rm dist/*.whl || error_exit "Failed to remove dist/*.whl"
mv *.whl dist/ || error_exit "Failed to move *.whl to dist/"
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage"
cd transformer_engine/jax
NVTE_RELEASE_BUILD=1 python3 setup.py sdist
NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist"
pip3 install dist/*
pip3 install dist/* || error_exit "Failed to install dist/*"
cd $TE_PATH
pip3 install dist/*.whl --no-deps
pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps"
python3 $TE_PATH/tests/jax/test_sanity_import.py || test_fail "test_sanity_import.py"
python3 $TE_PATH/tests/jax/test_sanity_import.py
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
......@@ -2,29 +2,46 @@
#
# See LICENSE for license information.
set -x
function error_exit() {
echo "Error: $1"
exit 1
}
: ${TE_PATH:=/opt/transformerengine}
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
pip3 install pytest==8.2.1
RET=0
FAILED_CASES=""
FAIL=0
set -x
: ${TE_PATH:=/opt/transformerengine}
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || FAIL=1
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || FAIL=1
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
exit $FAIL
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
......@@ -2,34 +2,53 @@
#
# See LICENSE for license information.
set -e
function error_exit() {
echo "Error: $1"
exit 1
}
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
RET=0
FAILED_CASES=""
: "${TE_PATH:=/opt/transformerengine}"
pip3 install wheel
pip3 install wheel || error_exit "Failed to install wheel"
cd $TE_PATH
pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch
pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch || error_exit "Failed to uninstall transformer-engine transformer-engine-cu12 transformer-engine-torch"
VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel
wheel unpack dist/*
NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel"
wheel unpack dist/* || error_exit "Failed to unpack dist/*"
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}"
rm dist/*.whl || error_exit "Failed to remove dist/*.whl"
mv *.whl dist/ || error_exit "Failed to move *.whl to dist/"
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage"
cd transformer_engine/pytorch
NVTE_RELEASE_BUILD=1 python3 setup.py sdist
NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist"
pip3 install dist/*
pip3 install dist/* || error_exit "Failed to install dist/*"
cd $TE_PATH
pip3 install dist/*.whl --no-deps
pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps"
python3 $TE_PATH/tests/pytorch/test_sanity_import.py || test_fail "test_sanity_import.py"
python3 $TE_PATH/tests/pytorch/test_sanity_import.py
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
......@@ -6,4 +6,4 @@ set -xe
: ${TE_PATH:=/opt/transformerengine}
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
......@@ -2,17 +2,34 @@
#
# See LICENSE for license information.
: ${TE_PATH:=/opt/transformerengine}
function error_exit() {
echo "Error: $1"
exit 1
}
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
pip3 install pytest==8.2.1
RET=0
FAILED_CASES=""
: ${TE_PATH:=/opt/transformerengine}
FAIL=0
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || FAIL=1
# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || FAIL=1 ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
exit $FAIL
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
......@@ -17,7 +17,7 @@ if [ $sm_arch -gt 90 ]
then
FA_versions=(2.7.3)
else
FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
FA_versions=(2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
fi
for fa_version in "${FA_versions[@]}"
......@@ -28,13 +28,15 @@ do
then
pip3 install flash-attn==${fa_version}
else
pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
python_path=`python3 -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flashattn_hopper
wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flash_attn_3
wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py
cd ../../
fi
# Run tests
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
done
......@@ -115,21 +115,50 @@ void compute_ref_x1(const ProcessingMethod processing_method,
const size_t block_size_X,
const size_t scales_stride)
{
std::vector<float> output_dbias_fp32(cols, 0);
const size_t tile_size_Y = std::max(32lu, block_size_Y);
const size_t tile_size_X = std::max(64lu, block_size_X);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y;
const size_t blocks_X = (cols + block_size_X - 1) / block_size_X;
for (size_t ii = 0; ii < blocks_Y; ++ii) {
const size_t i_min = ii * block_size_Y;
const size_t i_max = std::min((ii + 1) * block_size_Y, rows);
for (size_t jj = 0; jj < blocks_X; ++jj) {
const size_t j_min = jj * block_size_X;
const size_t j_max = std::min((jj + 1) * block_size_X, cols);
const size_t scale_idx = ii * scales_stride + jj;
scale_block<InputType, OutputType, OP>(
processing_method, input, grad, output_c, output_dbias_fp32.data(),
output_scales, scale_idx, i_min, i_max, j_min, j_max, cols);
std::vector<float> output_dbias_fp32(cols, 0);
#pragma omp parallel proc_bind(spread)
{
std::vector<float> thread_dbias(cols, 0);
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
const size_t block_offset_Y = ii * block_size_Y;
const size_t i_min = tile_offset_Y + block_offset_Y;
if (i_min >= rows) continue;
const size_t i_max = std::min(i_min + block_size_Y, rows);
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t scale_idx = block_idx_Y * scales_stride + block_idx_X;
scale_block<InputType, OutputType, OP>(
processing_method, input, grad, output_c, thread_dbias.data(),
output_scales, scale_idx, i_min, i_max, j_min, j_max, cols);
}
}
}
#pragma omp critical
{
for (size_t j = 0; j < cols; ++j) {
output_dbias_fp32[j] += thread_dbias[j];
}
}
}
for (size_t j = 0; j < cols; ++j) {
......
......@@ -61,18 +61,38 @@ void compute_ref_x1(const InputType* input,
const size_t block_size_X,
const size_t scales_stride)
{
const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y;
const size_t blocks_X = (cols + block_size_X - 1) / block_size_X;
for (size_t ii = 0; ii < blocks_Y; ++ii) {
const size_t i_min = ii * block_size_Y;
const size_t i_max = std::min((ii + 1) * block_size_Y, rows);
for (size_t jj = 0; jj < blocks_X; ++jj) {
const size_t j_min = jj * block_size_X;
const size_t j_max = std::min((jj + 1) * block_size_X, cols);
const size_t scale_idx = ii * scales_stride + jj;
dequantize_block<InputType, OutputType>(
input, output, scales, scale_idx, i_min, i_max, j_min, j_max, cols);
const size_t tile_size_Y = std::max(32lu, block_size_Y);
const size_t tile_size_X = std::max(64lu, block_size_X);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
#pragma omp parallel for schedule(static) proc_bind(spread)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
const size_t block_offset_Y = ii * block_size_Y;
const size_t i_min = tile_offset_Y + block_offset_Y;
if (i_min >= rows) continue;
const size_t i_max = std::min(i_min + block_size_Y, rows);
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t scale_idx = block_idx_Y * scales_stride + block_idx_X;
dequantize_block<InputType, OutputType>(
input, output, scales, scale_idx, i_min, i_max, j_min, j_max, cols);
}
}
}
}
......
......@@ -683,10 +683,15 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
#pragma omp parallel proc_bind(spread)
{
std::mt19937 gen_local = *gen;
gen_local.discard(omp_get_thread_num() * 599);
const int thread_ID = omp_get_thread_num();
const int threads_num = omp_get_max_threads();
const int chunk_size = (size + threads_num - 1) / threads_num;
const int idx_min = chunk_size * thread_ID;
const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
gen_local.discard(idx_min);
std::uniform_real_distribution<> dis(-2.0, 1.0);
#pragma omp for schedule(static)
for (size_t i = 0; i < size; ++i) {
for (int i = idx_min; i < idx_max; ++i) {
data[i] = static_cast<T>(dis(gen_local));
}
}
......
......@@ -27,6 +27,7 @@ from transformer_engine.pytorch.dot_product_attention.utils import (
check_set_window_size,
AttentionParams,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
......@@ -97,6 +98,8 @@ class ModelConfig:
num_layers: int = 1,
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
total_requests: int = None,
max_ctx_len: int = None,
):
self.batch_size = batch_size
self.num_heads = num_heads
......@@ -115,6 +118,8 @@ class ModelConfig:
self.num_layers = num_layers
self.bias_shape = bias_shape
self.window_size = window_size
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
@contextmanager
......@@ -137,6 +142,8 @@ def _get_attention_backends(
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
is_training: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> Tuple[List, List]:
"""Check if what attention backends support a model configuration"""
......@@ -166,6 +173,7 @@ def _get_attention_backends(
fused_attn_backends = []
available_backends = None
flash_attention_backend = None
fused_attention_backend = None
def test():
......@@ -191,10 +199,13 @@ def _get_attention_backends(
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
)
(
use_flash_attention,
use_fused_attention,
flash_attention_backend,
fused_attention_backend,
use_unfused_attention,
available_backends,
......@@ -203,20 +214,21 @@ def _get_attention_backends(
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["flash_attention_backend"] = flash_attention_backend
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return available_backends, fused_attention_backend
return available_backends, flash_attention_backend, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, fused_attention_backend = test()
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends
return available_backends, flash_attention_backend, fused_attn_backends
model_configs_base = {
......@@ -268,7 +280,7 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
available_backends, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......@@ -1153,7 +1165,7 @@ def test_transformer_layer(
qkv_layout = "sbhd_sbhd_sbhd"
# Test backend availability
available_backends, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......
This diff is collapsed.
......@@ -2,9 +2,10 @@
#
# See LICENSE for license information.
from collections import OrderedDict
import math
import os
from typing import Dict, List, Optional
from typing import Dict, List, Tuple, Optional
import pytest
import copy
import random
......@@ -60,6 +61,8 @@ torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
torch._dynamo.config.recompile_limit = 16
class ModelConfig:
def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
......@@ -78,9 +81,9 @@ model_configs = {
model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
}
backends_inference = ["FlashAttention", "UnfusedAttention"]
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
......@@ -329,9 +332,9 @@ class TorchLayerNormLinear(nn.Module):
in_features: int,
out_features: int,
eps: float,
bias: bool = True,
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False,
bias: bool = True,
):
super().__init__()
if normalization == "LayerNorm":
......@@ -345,7 +348,7 @@ class TorchLayerNormLinear(nn.Module):
else:
raise RuntimeError("Unsupported normalization")
self.linear = nn.Linear(in_features, out_features)
self.linear = nn.Linear(in_features, out_features, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.layernorm(x))
......@@ -445,6 +448,7 @@ class TorchLayerNormMLP(nn.Module):
eps: float = 1e-5,
activation="gelu",
normalization: str = "LayerNorm",
bias: bool = True,
):
super().__init__()
if normalization == "LayerNorm":
......@@ -460,8 +464,8 @@ class TorchLayerNormMLP(nn.Module):
fc1_output_features = ffn_hidden_size
self.gelu = _supported_act[activation]
self.fc1 = nn.Linear(hidden_size, fc1_output_features)
self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)
self.fc1 = nn.Linear(hidden_size, fc1_output_features, bias=bias)
self.fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)
def forward(self, x):
t = self.gelu(self.fc1(self.ln(x)))
......@@ -672,8 +676,6 @@ def test_gpt_full_activation_recompute(
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for full recompute.")
config = model_configs[model]
......@@ -1039,6 +1041,8 @@ def _test_granular_accuracy(block, bs, dtype, config):
inp_hidden_states.retain_grad()
out = block(inp_hidden_states)
if isinstance(out, (List, Tuple)):
out = out[0]
loss = out.sum()
loss.backward()
......@@ -1117,32 +1121,53 @@ def test_dpa_accuracy(dtype, bs, model):
assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)
class TestReturnBiasModule(nn.Module):
def __init__(self, mod, **kwargs):
super().__init__()
self.te_module = mod(**kwargs)
self.return_bias = kwargs["return_bias"]
self.bias = kwargs["bias"]
def forward(self, x):
if self.return_bias:
out, bias = self.te_module(x)
if self.bias:
out = out + bias
return out
return self.te_module(x)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
def test_linear_accuracy(dtype, bs, model):
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_linear_accuracy(dtype, bs, model, return_bias, bias):
config = model_configs[model]
te_linear = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
te_linear = TestReturnBiasModule(
Linear,
in_features=config.hidden_size,
out_features=4 * config.hidden_size,
params_dtype=dtype,
return_bias=return_bias,
bias=bias,
device="cuda",
).eval()
)
torch_linear = torch.nn.Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
bias=bias,
device="cuda",
dtype=dtype,
).eval()
)
# Share params
with torch.no_grad():
torch_linear.weight = Parameter(te_linear.weight.clone())
torch_linear.bias = Parameter(te_linear.bias.clone())
torch_linear.weight = Parameter(te_linear.te_module.weight.clone())
if bias:
torch_linear.bias = Parameter(te_linear.te_module.bias.clone())
te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)
......@@ -1265,41 +1290,51 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_linear_accuracy(
dtype, bs, model, normalization, zero_centered_gamma, return_bias, bias
):
config = model_configs[model]
te_ln_linear = LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=True,
te_ln_linear = TestReturnBiasModule(
LayerNormLinear,
in_features=config.hidden_size,
out_features=4 * config.hidden_size,
eps=config.eps,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
return_bias=return_bias,
bias=bias,
device="cuda",
).eval()
)
torch_ln_linear = (
TorchLayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=True,
normalization=normalization,
zero_centered_gamma=zero_centered_gamma,
bias=bias,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_ln_linear.layernorm.weight = Parameter(te_ln_linear.layer_norm_weight.clone())
torch_ln_linear.layernorm.weight = Parameter(
te_ln_linear.te_module.layer_norm_weight.clone()
)
if normalization != "RMSNorm":
torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone())
torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone())
torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone())
torch_ln_linear.layernorm.bias = Parameter(
te_ln_linear.te_module.layer_norm_bias.clone()
)
torch_ln_linear.linear.weight = Parameter(te_ln_linear.te_module.weight.clone())
if bias:
torch_ln_linear.linear.bias = Parameter(te_ln_linear.te_module.bias.clone())
te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
......@@ -1339,17 +1374,22 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
config = model_configs[model]
te_ln_mlp = LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
te_ln_mlp = TestReturnBiasModule(
LayerNormMLP,
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
activation=activation,
normalization=normalization,
params_dtype=dtype,
return_bias=return_bias,
bias=bias,
device="cuda",
).eval()
)
torch_ln_mlp = (
TorchLayerNormMLP(
......@@ -1357,21 +1397,22 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
4 * config.hidden_size,
activation=activation,
normalization=normalization,
bias=bias,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.layer_norm_weight.clone())
torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.te_module.layer_norm_weight.clone())
if normalization != "RMSNorm":
torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone())
torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.fc1_weight.clone())
torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.fc1_bias.clone())
torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.fc2_weight.clone())
torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.fc2_bias.clone())
torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.te_module.layer_norm_bias.clone())
torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.te_module.fc1_weight.clone())
torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.te_module.fc2_weight.clone())
if bias:
torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.te_module.fc1_bias.clone())
torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.te_module.fc2_bias.clone())
te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)
......@@ -2040,14 +2081,25 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
@pytest.mark.parametrize("is_paged", [False, True])
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
reset_rng_states()
if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
pytest.skip("FusedAttention and FlashAttention do not support FP32")
if use_RoPE:
pytest.skip("KV cache does not support starting positions for RoPE")
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
elif backend == "UnfusedAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
config = model_configs_inference[model_key]
......@@ -2060,7 +2112,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
# Limits the max size of KV-cache
B_max = B
S_max = S + 2
S_max = S
if module == "TransformerLayer":
model = TransformerLayer(
......@@ -2090,7 +2142,17 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
.eval()
)
inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
inference_params = InferenceParams(
max_batch_size=B_max,
max_seqlen_kv=S_max,
num_heads_kv=H,
head_dim_k=head_size,
dtype=dtype,
is_paged=is_paged,
total_num_pages=int(B_max * S_max / 256),
page_size=256,
)
rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
input = torch.randn((S, B, D), dtype=dtype, device="cuda")
......@@ -2103,22 +2165,39 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
# Incrementaly generate outputs using KV-cache
step_dict = OrderedDict(zip(list(range(B)), [1] * B))
for i in range(S):
inference_params.pre_step(step_dict)
if input_format == "sbhd":
incremental_input = input[i].view(1, B, D)
else:
incremental_input = input[:, i, :].view(B, 1, D)
seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv = cu_seqlens_q.clone()
mask_type = "padding"
kwargs = {}
if module == "TransformerLayer":
kwargs["self_attn_mask_type"] = mask_type
else:
kwargs["attn_mask_type"] = mask_type
line_output = model(
hidden_states=incremental_input,
inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None,
**kwargs,
max_seqlen_q=1,
max_seqlen_kv=S,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
inference_params.sequence_len_offset += 1
if input_format == "sbhd":
incremental_output[i] = line_output.view(B, D)
incremental_output[i, :, :] = line_output.view(B, D)
else:
incremental_output[:, i, :] = line_output.view(B, D)
......
......@@ -58,9 +58,13 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t)
}
}
void nvte_cudnn_handle_init() {
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
}
void nvte_cudnn_handle_init() { auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); }
namespace detail {
void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); }
} // namespace detail
#endif
} // namespace transformer_engine
......@@ -70,6 +74,6 @@ namespace cudnn_frontend {
// This is needed to define the symbol `cudnn_dlhandle`
// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING
// to enable dynamic loading.
void *cudnn_dlhandle = nullptr;
void* cudnn_dlhandle = nullptr;
} // namespace cudnn_frontend
......@@ -11,40 +11,26 @@
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#include <cudnn_graph.h>
#endif
#include <cstdint>
#include <mutex>
#include "transformer_engine/transformer_engine.h"
#include "util/handle_manager.h"
namespace transformer_engine {
#ifndef __HIP_PLATFORM_AMD__
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
namespace detail {
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
void CreateCuDNNHandle(cudnnHandle_t* handle);
class cudnnExecutionPlanManager {
public:
static cudnnExecutionPlanManager &Instance() {
static thread_local cudnnExecutionPlanManager instance;
return instance;
}
} // namespace detail
cudnnHandle_t GetCudnnHandle() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] { cudnnCreate(&handle_); });
return handle_;
}
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
~cudnnExecutionPlanManager() {}
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
private:
cudnnHandle_t handle_ = nullptr;
};
using cudnnExecutionPlanManager = detail::HandleManager<cudnnHandle_t, detail::CreateCuDNNHandle>;
#endif
} // namespace transformer_engine
#endif
#endif // TRANSFORMER_ENGINE_CUDNN_UTILS_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment