Commit 4099aa8e authored by yuguo
parents c520cba3 96f9c6de
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: bug
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**Steps/Code to reproduce bug**
Please list the *minimal* steps or a minimal code snippet that allows us to reproduce the bug.
A helpful guide on how to craft a minimal bug report: http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports.
**Expected behavior**
A clear and concise description of what you expected to happen.
**Environment overview (please complete the following information)**
- Environment location: [Bare-metal, Docker, Cloud (specify cloud provider: AWS, Azure, GCP, Colab)]
- Method of Transformer Engine install: [pip install or from source]. Please specify exact commands you used to install.
- If method of install is [Docker], provide `docker pull` & `docker run` commands used
**Environment details**
If an NVIDIA Docker image is used, you don't need to specify these.
Otherwise, please provide:
- OS version
- PyTorch version
- Python version
- Transformer Engine version
- CUDA version
- cuDNN version
**Device details**
- GPU model
**Additional context**
Add any other context about the problem here.
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: feature request
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
Provide a code snippet showing how the new APIs/changes would be used by others.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
......@@ -9,7 +9,12 @@ set -e
TE_LIB_PATH=`pip3 show transformer-engine | grep Location | cut -d ' ' -f 2`
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH
# Set parallelization parameters
NUM_PHYSICAL_CORES=$(nproc)
NUM_PARALLEL_JOBS=4
cd $TE_PATH/tests/cpp
cmake -GNinja -Bbuild .
cmake --build build
ctest --test-dir build -j4
export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS))
ctest --test-dir build -j$NUM_PARALLEL_JOBS
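As a side note, the thread-partitioning arithmetic above is simple to sketch in Python (the core count here is hypothetical; the script reads the real value from nproc):
# Hypothetical core count; the script uses $(nproc) at runtime.
num_physical_cores = 32
num_parallel_jobs = 4  # matches NUM_PARALLEL_JOBS, the ctest -j value
# Give each concurrent ctest job an equal share of the cores so the
# OpenMP runtime inside each test does not oversubscribe the machine.
omp_num_threads = num_physical_cores // num_parallel_jobs
assert omp_num_threads == 8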
......@@ -2,14 +2,33 @@
#
# See LICENSE for license information.
set -xe
function error_exit() {
echo "Error: $1"
exit 1
}
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
RET=0
FAILED_CASES=""
: ${TE_PATH:=/opt/transformerengine}
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install requirements"
# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
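The error_exit/test_fail helpers implement a simple accumulate-and-report pattern: hard-fail on setup errors, but record test failures and keep running the remaining sub-tests. A rough Python equivalent of the same pattern (paths and labels illustrative only):
import subprocess

failed_cases = []

def run_subtest(label, cmd):
    # Record the failure but keep running the remaining sub-tests,
    # like the shell test_fail helper above.
    if subprocess.call(cmd) != 0:
        failed_cases.append(label)

run_subtest("test_multigpu_encoder.py",
            ["python3", "-m", "pytest", "-v", "test_multigpu_encoder.py"])

if failed_cases:
    raise SystemExit("Error: some sub-tests failed: " + " ".join(failed_cases))
print("All tests passed")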
......@@ -2,22 +2,41 @@
#
# See LICENSE for license information.
set -xe
function error_exit() {
echo "Error: $1"
exit 1
}
pip3 install "nltk>=3.8.2"
pip3 install pytest==8.2.1
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
RET=0
FAILED_CASES=""
pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
: ${TE_PATH:=/opt/transformerengine}
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py || test_fail "tests/jax"
# Test without custom calls
NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "test_mnist.py"
# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
......@@ -2,34 +2,53 @@
#
# See LICENSE for license information.
set -e
function error_exit() {
echo "Error: $1"
exit 1
}
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
RET=0
FAILED_CASES=""
: "${TE_PATH:=/opt/transformerengine}"
pip3 install wheel
pip3 install wheel || error_exit "Failed to install wheel"
cd $TE_PATH
pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax
pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax || error_exit "Failed to uninstall transformer-engine transformer-engine-cu12 transformer-engine-jax"
VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel
wheel unpack dist/*
NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel"
wheel unpack dist/* || error_exit "Failed to unpack dist/*"
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}"
rm dist/*.whl || error_exit "Failed to remove dist/*.whl"
mv *.whl dist/ || error_exit "Failed to move *.whl to dist/"
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage"
cd transformer_engine/jax
NVTE_RELEASE_BUILD=1 python3 setup.py sdist
NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist"
pip3 install dist/*
pip3 install dist/* || error_exit "Failed to install dist/*"
cd $TE_PATH
pip3 install dist/*.whl --no-deps
pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps"
python3 $TE_PATH/tests/jax/test_sanity_import.py || test_fail "test_sanity_import.py"
python3 $TE_PATH/tests/jax/test_sanity_import.py
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
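The unpack/sed/mv/pack sequence above renames the built wheel from transformer-engine to transformer-engine-cu12. A hedged Python sketch of the same rename, assuming the directory layout produced by `wheel unpack` (the version string is hypothetical; the script reads it from build_tools/VERSION.txt):
from pathlib import Path

version = "1.0.0"  # hypothetical; taken from VERSION.txt in the script
base = Path(f"transformer_engine-{version}")
dist_info = base / f"transformer_engine-{version}.dist-info"
metadata = dist_info / "METADATA"

# Rewrite the distribution name so the repacked wheel installs as the
# CUDA-12-specific package, mirroring the two sed commands above.
text = metadata.read_text()
text = text.replace("Name: transformer-engine", "Name: transformer-engine-cu12")
text = text.replace("Name: transformer_engine", "Name: transformer_engine_cu12")
metadata.write_text(text)

# The .dist-info directory name must match before `wheel pack`.
dist_info.rename(base / f"transformer_engine_cu12-{version}.dist-info")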
......@@ -2,29 +2,46 @@
#
# See LICENSE for license information.
set -x
function error_exit() {
echo "Error: $1"
exit 1
}
: ${TE_PATH:=/opt/transformerengine}
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
pip3 install pytest==8.2.1
RET=0
FAILED_CASES=""
FAIL=0
set -x
: ${TE_PATH:=/opt/transformerengine}
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || FAIL=1
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || FAIL=1
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
exit $FAIL
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
......@@ -2,34 +2,53 @@
#
# See LICENSE for license information.
set -e
function error_exit() {
echo "Error: $1"
exit 1
}
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
RET=0
FAILED_CASES=""
: "${TE_PATH:=/opt/transformerengine}"
pip3 install wheel
pip3 install wheel || error_exit "Failed to install wheel"
cd $TE_PATH
pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch
pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch || error_exit "Failed to uninstall transformer-engine transformer-engine-cu12 transformer-engine-torch"
VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel
wheel unpack dist/*
NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel"
wheel unpack dist/* || error_exit "Failed to unpack dist/*"
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}"
rm dist/*.whl || error_exit "Failed to remove dist/*.whl"
mv *.whl dist/ || error_exit "Failed to move *.whl to dist/"
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage"
cd transformer_engine/pytorch
NVTE_RELEASE_BUILD=1 python3 setup.py sdist
NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist"
pip3 install dist/*
pip3 install dist/* || error_exit "Failed to install dist/*"
cd $TE_PATH
pip3 install dist/*.whl --no-deps
pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps"
python3 $TE_PATH/tests/pytorch/test_sanity_import.py || test_fail "test_sanity_import.py"
python3 $TE_PATH/tests/pytorch/test_sanity_import.py
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
......@@ -6,4 +6,4 @@ set -xe
: ${TE_PATH:=/opt/transformerengine}
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
......@@ -2,17 +2,34 @@
#
# See LICENSE for license information.
: ${TE_PATH:=/opt/transformerengine}
function error_exit() {
echo "Error: $1"
exit 1
}
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
pip3 install pytest==8.2.1
RET=0
FAILED_CASES=""
: ${TE_PATH:=/opt/transformerengine}
FAIL=0
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || FAIL=1
# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || FAIL=1 ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || FAIL=1
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
exit $FAIL
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
......@@ -17,7 +17,7 @@ if [ $sm_arch -gt 90 ]
then
FA_versions=(2.7.3)
else
FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
FA_versions=(2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
fi
for fa_version in "${FA_versions[@]}"
......@@ -28,13 +28,15 @@ do
then
pip3 install flash-attn==${fa_version}
else
pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
python_path=`python3 -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flashattn_hopper
wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flash_attn_3
wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py
cd ../../
fi
# Run tests
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
done
......@@ -115,21 +115,50 @@ void compute_ref_x1(const ProcessingMethod processing_method,
const size_t block_size_X,
const size_t scales_stride)
{
std::vector<float> output_dbias_fp32(cols, 0);
const size_t tile_size_Y = std::max(32lu, block_size_Y);
const size_t tile_size_X = std::max(64lu, block_size_X);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y;
const size_t blocks_X = (cols + block_size_X - 1) / block_size_X;
for (size_t ii = 0; ii < blocks_Y; ++ii) {
const size_t i_min = ii * block_size_Y;
const size_t i_max = std::min((ii + 1) * block_size_Y, rows);
for (size_t jj = 0; jj < blocks_X; ++jj) {
const size_t j_min = jj * block_size_X;
const size_t j_max = std::min((jj + 1) * block_size_X, cols);
const size_t scale_idx = ii * scales_stride + jj;
scale_block<InputType, OutputType, OP>(
processing_method, input, grad, output_c, output_dbias_fp32.data(),
output_scales, scale_idx, i_min, i_max, j_min, j_max, cols);
std::vector<float> output_dbias_fp32(cols, 0);
#pragma omp parallel proc_bind(spread)
{
std::vector<float> thread_dbias(cols, 0);
#pragma omp for schedule(static)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
const size_t block_offset_Y = ii * block_size_Y;
const size_t i_min = tile_offset_Y + block_offset_Y;
if (i_min >= rows) continue;
const size_t i_max = std::min(i_min + block_size_Y, rows);
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t scale_idx = block_idx_Y * scales_stride + block_idx_X;
scale_block<InputType, OutputType, OP>(
processing_method, input, grad, output_c, thread_dbias.data(),
output_scales, scale_idx, i_min, i_max, j_min, j_max, cols);
}
}
}
#pragma omp critical
{
for (size_t j = 0; j < cols; ++j) {
output_dbias_fp32[j] += thread_dbias[j];
}
}
}
for (size_t j = 0; j < cols; ++j) {
......
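The new reference loop splits the matrix into OpenMP-parallel tiles of at least 32x64 and walks the quantization blocks inside each tile, with guards for partial blocks at the ragged edges; each thread also accumulates dbias into a private buffer and merges it under a critical section. A Python paraphrase of the tile-to-block index mapping (a sketch, not the test's code):
import math

def blocks_in_tile(t, rows, cols, block_size_Y, block_size_X, scales_stride):
    """Yield (scale_idx, row range, col range) for every block in flat
    tile t, mirroring the OpenMP loop body above."""
    tile_size_Y = max(32, block_size_Y)
    tile_size_X = max(64, block_size_X)
    tiles_num_X = math.ceil(cols / tile_size_X)
    blocks_per_tile_Y = tile_size_Y // block_size_Y
    blocks_per_tile_X = tile_size_X // block_size_X
    tile_Y, tile_X = divmod(t, tiles_num_X)
    for ii in range(blocks_per_tile_Y):
        i_min = tile_Y * tile_size_Y + ii * block_size_Y
        if i_min >= rows:
            continue
        i_max = min(i_min + block_size_Y, rows)
        for jj in range(blocks_per_tile_X):
            j_min = tile_X * tile_size_X + jj * block_size_X
            if j_min >= cols:
                continue
            j_max = min(j_min + block_size_X, cols)
            scale_idx = ((tile_Y * blocks_per_tile_Y + ii) * scales_stride
                         + tile_X * blocks_per_tile_X + jj)
            yield scale_idx, (i_min, i_max), (j_min, j_max)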
......@@ -61,18 +61,38 @@ void compute_ref_x1(const InputType* input,
const size_t block_size_X,
const size_t scales_stride)
{
const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y;
const size_t blocks_X = (cols + block_size_X - 1) / block_size_X;
for (size_t ii = 0; ii < blocks_Y; ++ii) {
const size_t i_min = ii * block_size_Y;
const size_t i_max = std::min((ii + 1) * block_size_Y, rows);
for (size_t jj = 0; jj < blocks_X; ++jj) {
const size_t j_min = jj * block_size_X;
const size_t j_max = std::min((jj + 1) * block_size_X, cols);
const size_t scale_idx = ii * scales_stride + jj;
dequantize_block<InputType, OutputType>(
input, output, scales, scale_idx, i_min, i_max, j_min, j_max, cols);
const size_t tile_size_Y = std::max(32lu, block_size_Y);
const size_t tile_size_X = std::max(64lu, block_size_X);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
#pragma omp parallel for schedule(static) proc_bind(spread)
for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
const size_t tile_Y = t / tiles_num_X;
const size_t tile_X = t % tiles_num_X;
const size_t tile_offset_Y = tile_Y * tile_size_Y;
const size_t tile_offset_X = tile_X * tile_size_X;
for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) {
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
const size_t block_offset_Y = ii * block_size_Y;
const size_t i_min = tile_offset_Y + block_offset_Y;
if (i_min >= rows) continue;
const size_t i_max = std::min(i_min + block_size_Y, rows);
for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);
const size_t scale_idx = block_idx_Y * scales_stride + block_idx_X;
dequantize_block<InputType, OutputType>(
input, output, scales, scale_idx, i_min, i_max, j_min, j_max, cols);
}
}
}
}
......
......@@ -683,10 +683,15 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
#pragma omp parallel proc_bind(spread)
{
std::mt19937 gen_local = *gen;
gen_local.discard(omp_get_thread_num() * 599);
const int thread_ID = omp_get_thread_num();
const int threads_num = omp_get_max_threads();
const int chunk_size = (size + threads_num - 1) / threads_num;
const int idx_min = chunk_size * thread_ID;
const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
gen_local.discard(idx_min);
std::uniform_real_distribution<> dis(-2.0, 1.0);
#pragma omp for schedule(static)
for (size_t i = 0; i < size; ++i) {
for (int i = idx_min; i < idx_max; ++i) {
data[i] = static_cast<T>(dis(gen_local));
}
}
......
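The change makes parallel data generation independent of the OpenMP schedule: each thread copies the generator, skips ahead to the start of its own contiguous chunk, and fills only that chunk. A sketch of the same idea using NumPy's counter-based Philox generator (an analogy only; the test uses mt19937::discard, whose skip count is approximate for real-valued distributions):
import numpy as np

def fill_chunk(seed, size, thread_id, threads_num, out):
    chunk = (size + threads_num - 1) // threads_num
    lo = chunk * thread_id
    hi = min(chunk * (thread_id + 1), size)
    bg = np.random.Philox(seed)
    bg.advance(lo)  # skip to this chunk's slice, like gen_local.discard(idx_min)
    rng = np.random.Generator(bg)
    out[lo:hi] = rng.uniform(-2.0, 1.0, hi - lo)

# Any partition into threads produces the same array.
out = np.empty(1000)
for tid in range(4):
    fill_chunk(1234, out.size, tid, 4, out)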
......@@ -27,6 +27,7 @@ from transformer_engine.pytorch.dot_product_attention.utils import (
check_set_window_size,
AttentionParams,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
......@@ -97,6 +98,8 @@ class ModelConfig:
num_layers: int = 1,
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
total_requests: int = None,
max_ctx_len: int = None,
):
self.batch_size = batch_size
self.num_heads = num_heads
......@@ -115,6 +118,8 @@ class ModelConfig:
self.num_layers = num_layers
self.bias_shape = bias_shape
self.window_size = window_size
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
@contextmanager
......@@ -137,6 +142,8 @@ def _get_attention_backends(
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
is_training: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> Tuple[List, List]:
"""Check if what attention backends support a model configuration"""
......@@ -166,6 +173,7 @@ def _get_attention_backends(
fused_attn_backends = []
available_backends = None
flash_attention_backend = None
fused_attention_backend = None
def test():
......@@ -191,10 +199,13 @@ def _get_attention_backends(
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
)
(
use_flash_attention,
use_fused_attention,
flash_attention_backend,
fused_attention_backend,
use_unfused_attention,
available_backends,
......@@ -203,20 +214,21 @@ def _get_attention_backends(
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["flash_attention_backend"] = flash_attention_backend
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return available_backends, fused_attention_backend
return available_backends, flash_attention_backend, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, fused_attention_backend = test()
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends
return available_backends, flash_attention_backend, fused_attn_backends
model_configs_base = {
......@@ -268,7 +280,7 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
available_backends, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......@@ -1153,7 +1165,7 @@ def test_transformer_layer(
qkv_layout = "sbhd_sbhd_sbhd"
# Test backend availability
available_backends, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from collections import OrderedDict
from typing import List
import os
import logging
import math
import pytest
import torch
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
from transformer_engine.pytorch import fp8_autocast, fp8_model_init
from transformer_engine.pytorch.transformer import (
TransformerLayer,
)
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
)
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
param_types = [torch.float16]
if is_bf16_compatible():
param_types.append(torch.bfloat16)
model_configs_infer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"infer_0": ModelConfig(
4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
),
"infer_1": ModelConfig(
2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
),
}
qkv_formats = ["bshd", "sbhd", "thd"]
def to_pretty_string(x: torch.Tensor):
return "[" + ",".join(["{:>3s}".format(str(i)) for i in x.tolist()]) + "]"
def round_up(a: int, b: int):
return b * math.ceil(a / b)
class Simulation:
def __init__(
self,
total_requests: int = 10,
max_seq_len: int = 1024,
max_ctx_len: int = 128,
max_batch_size: int = 5,
poisson_rate: float = 1,
):
self.total_requests = total_requests
self.max_seq_len = max_seq_len
self.max_batch_size = max_batch_size
self.poisson_rate = poisson_rate
# calculate maximum context/generation length
self.max_ctx_len = max_ctx_len
self.max_gen_len = max_seq_len - self.max_ctx_len
# simulate sequence ids in monotonically increasing fashion
self.seq_ids = torch.range(0, total_requests - 1, dtype=torch.int32, device="cpu")
# simulate context lengths in Uniform distribution
self.context_lens = torch.randint(
1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu"
)
# simulate gen lengths in Exponential distribution
gen_dist = Exponential(1 / self.max_gen_len)
gen_lens = gen_dist.sample((total_requests,))
gen_lens = torch.where(gen_lens > self.max_gen_len, self.max_gen_len, gen_lens).to(
dtype=torch.int32, device="cpu"
)
self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to(dtype=torch.int32, device="cpu")
# simulate arrival times in Poisson distribution
if poisson_rate is None:
self.poisson_rate = torch.randint(1, max_batch_size, [1]).item()
interval_dist = Exponential(self.poisson_rate)
arrival_intervals = interval_dist.sample((total_requests,))
self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to(
dtype=torch.int32, device="cpu"
)
self.last_arrival = self.arrival_times.max().item()
# initialize tensors
self.reset()
def reset(self):
self.t = 0
self.request_delays = torch.zeros([self.total_requests], dtype=torch.int32, device="cpu")
self.delayed_seq_ids = torch.Tensor().to(dtype=torch.int32, device="cpu")
self.serving_times = self.arrival_times
self.complete_times = self.arrival_times
# batch info at step t
self.t_seq_ids = torch.Tensor([]).to(dtype=torch.bool, device="cpu")
self.t_ctx_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu")
self.t_gen_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu")
self.t_total_lens = self.t_ctx_lens + self.t_gen_lens
self.t_batch_size = 0
# step info from step t-1 to t
self.step_lens = torch.Tensor([]).to(dtype=torch.int32, device="cpu")
def print_setup(self, logger):
logger.info("Simulation:")
logger.info(" {:<31s}: {}".format("total number of requests", self.total_requests))
logger.info(" {:<31s}: {}".format("max sequence length per request", self.max_seq_len))
logger.info(" {:<31s}: {}".format("max context length", self.max_ctx_len))
logger.info(" {:<31s}: {}".format("max generation length", self.max_gen_len))
logger.info(" {:<31s}: {}".format("max batch size per iteration", self.max_batch_size))
logger.info(" {:<31s}: {}".format("Poisson rate", self.poisson_rate))
logger.info(" {:<17s}: {}".format("sequence ids", to_pretty_string(self.seq_ids)))
logger.info(" {:<17s}: {}".format("arrival times", to_pretty_string(self.arrival_times)))
logger.info(" {:<17s}: {}".format("context lengths", to_pretty_string(self.context_lens)))
logger.info(" {:<17s}: {}".format("generation lengths", to_pretty_string(self.gen_lens)))
def print_step(self, logger):
logger.info(f"Step t = {self.t}:")
logger.info(" {:<15s}: {}".format("t_batch_size", self.t_batch_size))
logger.info(" {:<15s}: {}".format("t_seq_ids", self.t_seq_ids.tolist()))
logger.info(" {:<15s}: {}".format("t_ctx_lens", self.t_ctx_lens.tolist()))
logger.info(" {:<15s}: {}".format("t_gen_lens", self.t_gen_lens.tolist()))
logger.info(" {:<15s}: {}".format("t_total_lens", self.t_total_lens.tolist()))
logger.info(" {:<15s}: {}".format("step_lens", self.step_lens.tolist()))
def print_summary(self, logger):
logger.info("Summary:")
logger.info(" {:<18s}: {}".format("total steps taken", self.t))
logger.info(" {:<18s}: {}".format("arrival_times", to_pretty_string(self.arrival_times)))
logger.info(" {:<18s}: {}".format("serving_times", to_pretty_string(self.serving_times)))
logger.info(" {:<18s}: {}".format("total_gen_lens", to_pretty_string(self.gen_lens)))
logger.info(" {:<18s}: {}".format("complete_times", to_pretty_string(self.complete_times)))
def add_new_seqs(self, new_seq_ids):
# get ctx_lens for new seqs
self.t_seq_ids = torch.cat([self.t_seq_ids, new_seq_ids], dim=0)
self.t_ctx_lens = torch.cat([self.t_ctx_lens, self.context_lens[new_seq_ids]], dim=0)
gen_lens = torch.Tensor([0] * len(new_seq_ids)).to(dtype=torch.int32, device="cpu")
self.t_gen_lens = torch.cat([self.t_gen_lens, gen_lens], dim=0)
# append new seqs' ctx_lens to step_lens
self.step_lens = torch.cat([self.step_lens, self.context_lens[new_seq_ids]], dim=0)
def remove_finished(self):
# figure out which seqs have finished
finished = torch.where(self.t_gen_lens - self.gen_lens[self.t_seq_ids] < 0, False, True).to(
dtype=torch.bool, device="cpu"
)
self.t_seq_ids = self.t_seq_ids[~finished]
self.t_ctx_lens = self.t_ctx_lens[~finished]
self.t_gen_lens = self.t_gen_lens[~finished]
# add ones for unfinished seqs to step_lens
self.step_lens = torch.ones([len(self.t_seq_ids)], dtype=torch.int32, device="cpu")
def step(self, dynamic_fill: bool = True):
# remove finished seqs
if self.t != 0:
self.remove_finished()
# get allowed new seqs
arrived_seq_ids = torch.where(self.arrival_times == self.t, True, False).nonzero().view(-1)
queuing_seq_ids = torch.cat([self.delayed_seq_ids, arrived_seq_ids], dim=0)
if dynamic_fill:
allowed_num_new_seqs = self.max_batch_size - len(self.t_seq_ids)
else:
allowed_num_new_seqs = 0 if len(self.t_seq_ids) else self.max_batch_size
if len(queuing_seq_ids) > allowed_num_new_seqs:
new_seq_ids = queuing_seq_ids[:allowed_num_new_seqs]
self.delayed_seq_ids = queuing_seq_ids[allowed_num_new_seqs:]
self.request_delays[self.delayed_seq_ids.tolist()] += 1
else:
new_seq_ids = queuing_seq_ids
self.delayed_seq_ids = torch.Tensor().to(dtype=torch.int32)
# add new seqs to batch
self.add_new_seqs(new_seq_ids)
# update batch variables
self.t_batch_size = len(self.t_seq_ids)
self.t_total_lens = self.t_ctx_lens + self.t_gen_lens
def get_model(
module: torch.nn.Module,
config: ModelConfig,
dtype: torch.dtype,
backend: str = "FusedAttention",
qkv_format: str = "bshd",
num_layers: int = 1,
mode: str = "reference",
is_fp8: bool = False,
):
reset_rng_states()
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, num_layers)
if mode == "reference":
attn_mask_type = "causal"
qkv_format = "bshd"
if mode == "inference":
attn_mask_type = "padding_causal" if backend != "FusedAttention" else "padding"
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=is_fp8,
fp8_mha=False,
)
if module == "TransformerLayer":
hidden_size = config.head_dim_qk * config.num_heads
with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe):
model = [
TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=4 * hidden_size,
num_attention_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
hidden_dropout=0.0,
attention_dropout=config.dropout_p,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=config.head_dim_qk,
self_attn_mask_type=attn_mask_type,
fuse_qkv_params=False,
params_dtype=dtype,
attn_input_format=qkv_format,
)
.cuda()
.eval()
for layer_number in range(1, num_layers + 1)
]
if module == "DotProductAttention":
with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe):
model = [
DotProductAttention(
kv_channels=config.head_dim_qk,
num_attention_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
layer_number=layer_number,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
)
.cuda()
.eval()
for layer_number in range(1, num_layers + 1)
]
return model
def generate_args(
module: torch.nn.Module,
config: ModelConfig,
dtype: torch.dtype,
qkv_format: str = "bshd",
mode: str = "full_inputs",
):
# full inputs used as reference
if mode == "full_inputs":
warmup = False
shapes = []
if module == "TransformerLayer":
shapes.append(
[config.total_requests, config.max_seqlen_kv, config.num_heads * config.head_dim_qk]
)
if module == "DotProductAttention":
shapes.append(
[config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk]
)
shapes.append(
[
config.total_requests,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim_qk,
]
)
shapes.append(
[
config.total_requests,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim_v,
]
)
# sample args used for cuda graph warmup
elif mode == "sample_args":
warmup = True
shapes = []
if qkv_format == "bshd":
shape = [config.batch_size, config.max_ctx_len]
if qkv_format == "sbhd":
shape = [config.max_ctx_len, config.batch_size]
if qkv_format == "thd":
shape = [config.batch_size * config.max_ctx_len]
if module == "TransformerLayer":
shapes.append([*shape, config.num_heads * config.head_dim_qk])
if module == "DotProductAttention":
shapes.append([*shape, config.num_heads, config.head_dim_qk])
shapes.append([*shape, config.num_gqa_groups, config.head_dim_qk])
shapes.append([*shape, config.num_gqa_groups, config.head_dim_v])
num_tensors = len(shapes)
if warmup:
return [
torch.ones(
*shapes[i],
device="cuda",
dtype=dtype,
)
for i in range(num_tensors)
]
elif module == "TransformerLayer":
return [
0.01
* torch.randint(
-100,
100,
shapes[i],
device="cuda",
dtype=dtype,
)
for i in range(num_tensors)
]
elif module == "DotProductAttention":
return [
0.1
* torch.randn(
*shapes[i],
device="cuda",
dtype=dtype,
)
for i in range(num_tensors)
]
def get_tols(module, backend, dtype):
if module == "TransformerLayer":
tols = {
torch.half: (5e-3, 5e-3),
torch.bfloat16: (3.5e-2, 3.5e-2),
}
if module == "DotProductAttention":
tols = {
torch.half: (1e-3, 1e-3),
torch.bfloat16: (1e-2, 1e-3),
torch.float8_e4m3fn: (2e-2, 3e-2),
}
return tols[dtype]
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", model_configs_infer.keys())
@pytest.mark.parametrize("qkv_format", qkv_formats)
@pytest.mark.parametrize("is_paged", [False, True])
@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"])
@pytest.mark.parametrize("is_cuda_graph", [False, True])
@pytest.mark.parametrize("is_fp8", [False, True])
def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8):
reset_rng_states()
logger = logging.getLogger("test_paged_attn")
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=is_fp8,
fp8_mha=False,
)
fp8_meta = {}
fp8_meta["recipe"] = fp8_recipe
config = model_configs_infer[model]
num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1
# flash-attn v2 requires page_size >= 256
if backend == "FlashAttention" and not fa_utils.v3_is_installed:
config_max_seqlen_q = config.max_seqlen_q
config_max_seqlen_kv = config.max_seqlen_kv
config.max_seqlen_q = 256
config.max_seqlen_kv = 256
# create a real-life simulation
max_batch_size = config.batch_size
page_size = None
total_num_pages = None
if is_paged:
page_size = 256 if backend == "FlashAttention" and not fa_utils.v3_is_installed else 1
config.max_seqlen_kv = round_up(config.max_seqlen_kv, page_size)
total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size)
else:
config.max_seqlen_kv = round_up(config.max_seqlen_kv, 64)
sim = Simulation(
total_requests=config.total_requests,
max_seq_len=config.max_seqlen_kv,
max_ctx_len=config.max_ctx_len,
max_batch_size=max_batch_size,
poisson_rate=2,
)
sim.print_setup(logger)
# initialize inference_params
inference_params = InferenceParams(
max_batch_size=max_batch_size,
max_seqlen_kv=config.max_seqlen_kv,
num_heads_kv=config.num_gqa_groups,
head_dim_k=config.head_dim_qk,
head_dim_v=config.head_dim_v,
dtype=dtype,
is_paged=is_paged,
page_size=page_size,
total_num_pages=total_num_pages,
max_ctx_len=config.max_ctx_len,
qkv_format=qkv_format,
)
if module == "DotProductAttention":
for layer_number in range(1, num_layers + 1):
inference_params.allocate_memory(layer_number)
# figure out supported backends
inference_params_qkv_format = "bshd"
qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
if is_paged:
qkv_layout = "paged_kv_" + qkv_layout
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=False,
is_training=False,
fp8=is_fp8,
fp8_meta=fp8_meta,
inference_params=inference_params,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if backend == "FlashAttention" and not flash_attn_supported:
pytest.skip("FlashAttention backend is not supported")
if backend == "FusedAttention" and not fused_attn_supported:
pytest.skip("FusedAttention backend is not supported")
if backend == "UnfusedAttention" and not unfused_attn_supported:
pytest.skip("UnfusedAttention backend is not supported")
os.environ["NVTE_FLASH_ATTN"] = str(int(backend == "FlashAttention"))
os.environ["NVTE_FUSED_ATTN"] = str(int(backend == "FusedAttention"))
os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention"))
if backend == "UnfusedAttention" and is_cuda_graph:
pytest.skip("CUDA graph is not supported for UnfusedAttention backend")
# TransformerLayer FP8 TN GEMM currently requires dims % 8 == 0
if is_fp8 and not (qkv_format == "thd" and module == "DotProductAttention"):
pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported")
# create full model
logger.info("=== Generating all tokens at once ===")
model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="reference")
# generate data for all requests
full_inputs = generate_args(module, config, dtype, qkv_format="bshd", mode="full_inputs")
# generate reference results
if module == "DotProductAttention":
full_output = full_inputs
for m in model:
full_output = m(
*full_output if isinstance(full_output, List) else full_output,
)
if module == "TransformerLayer":
full_output = full_inputs
for m in model:
full_output = m(
full_output[0] if isinstance(full_output, List) else full_output,
)
# create inference model
logger.info("=== Generating one token at a time ===")
model = get_model(
module,
config,
dtype,
backend,
qkv_format,
num_layers,
mode="inference",
is_fp8=is_fp8,
)
# graph the model if necessary
if is_cuda_graph:
t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu")
step_lens = config.max_ctx_len * torch.ones(max_batch_size, dtype=torch.int32, device="cpu")
step_dict = OrderedDict(zip(t_seq_ids.tolist(), step_lens.tolist()))
inference_params.pre_step(step_dict)
sample_args = generate_args(
module, config, dtype, qkv_format=qkv_format, mode="sample_args"
)
sample_kwargs = {}
sample_kwargs["cu_seqlens_q"] = torch.linspace(
0,
config.batch_size * config.max_ctx_len,
steps=config.batch_size + 1,
device="cuda",
dtype=torch.int32,
)
sample_kwargs["cu_seqlens_kv"] = torch.linspace(
0,
config.batch_size * config.max_ctx_len,
steps=config.batch_size + 1,
device="cuda",
dtype=torch.int32,
)
sample_kwargs["inference_params"] = inference_params
sample_kwargs["max_seqlen_q"] = config.max_ctx_len
sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv
model = [
make_graphed_callables(
model[i],
sample_args,
num_warmup_iters=10,
fp8_enabled=is_fp8,
sample_kwargs=sample_kwargs,
fp8_recipe=fp8_recipe,
)
for i in range(num_layers)
]
sim.reset()
inference_params.reset()
step_dict = OrderedDict()
# simulate step by step
# t-1: ...
# compute for seq_ids = [0, 1, 2], ctx_lens = [5, 2, 3], gen_lens = [2, 9, 4],
# batch_size = 3, step_lens = [1, 1, 1]
# increase counter for gen_lens = [3, 10, 5]
# t: detect seq 1 is finished since expected_gen_lens = [12, 10, 15]
# add two new seqs 3 and 4, with ctx lens 10 and 11
# compute for seq_ids = [0, 2, 3, 4], ctx_lens = [5, 3, 10, 11], gen_lens = [3, 5, 0, 0],
# batch_size = 4, step_lens = [1, 1, 10, 11]
# increase counter for gen_lens = [4, 6, 1, 1]
max_tokens = config.batch_size * config.max_ctx_len
while True:
# prepare batch for the current step
dynamic_fill = True # inference_params.is_paged
sim.step(dynamic_fill=dynamic_fill)
sim.print_step(logger)
if sim.t_batch_size == 0:
# all sequences are finished
if sim.t > sim.last_arrival:
sim.serving_times = sim.arrival_times + sim.request_delays
sim.complete_times = sim.serving_times + sim.gen_lens
break
# not finished; run next iteration
else:
sim.t += 1
continue
# create incremental input
batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size
max_seqlen_q = sim.max_ctx_len if is_cuda_graph else max(sim.step_lens).item()
num_tensors = len(full_inputs)
if qkv_format == "thd":
incremental_inputs = []
for i in range(num_tensors):
inp = full_inputs[i]
inc_inp = torch.Tensor().to(dtype=dtype, device="cuda")
for i, seq in enumerate(sim.t_seq_ids):
start = (sim.t_total_lens[i] - sim.step_lens[i]).item()
end = sim.t_total_lens[i].item()
inc_inp = torch.cat([inc_inp, inp[seq, start:end]], dim=0)
if is_cuda_graph:
inc_inp = torch.cat(
[
inc_inp,
torch.zeros(
max_tokens - sum(sim.step_lens),
*inp.shape[2:],
dtype=dtype,
device=inc_inp.device,
),
],
dim=0,
)
incremental_inputs.append(inc_inp)
else:
incremental_inputs = []
for i in range(num_tensors):
inp = full_inputs[i]
inc_inp = torch.zeros(
batch_size,
max_seqlen_q,
*inp.shape[2:],
dtype=dtype,
device="cuda",
)
for i, seq in enumerate(sim.t_seq_ids):
start = (sim.t_total_lens[i] - sim.step_lens[i]).item()
end = sim.t_total_lens[i].item()
inc_inp[i, : sim.step_lens[i], :] = inp[seq, start:end]
if qkv_format == "sbhd":
inc_inp = inc_inp.transpose(0, 1).contiguous()
incremental_inputs.append(inc_inp)
# run step
batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size
cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1 : sim.t_batch_size + 1] = torch.cumsum(sim.step_lens, dim=0)
cu_seqlens_kv = cu_seqlens_q.clone()
step_dict = OrderedDict(zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist()))
inference_params.pre_step(step_dict)
if inference_params.is_paged:
inference_params.cache_manager.print_cache()
incremental_output = incremental_inputs
with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe):
for m in model:
incremental_output = m(
*incremental_output,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
inference_params=inference_params,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
)
incremental_output = [incremental_output]
incremental_output = incremental_output[0]
# compare results
atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn)
for i, seq in enumerate(sim.t_seq_ids):
token_index = sim.step_lens[i] - 1
if qkv_format == "bshd":
torch.testing.assert_close(
full_output[seq, sim.t_total_lens[i] - 1, :],
incremental_output[i, sim.step_lens[i] - 1, :],
atol=atol,
rtol=rtol,
)
if qkv_format == "sbhd":
torch.testing.assert_close(
full_output[seq, sim.t_total_lens[i] - 1, :],
incremental_output[sim.step_lens[i] - 1, i, :],
atol=atol,
rtol=rtol,
)
if qkv_format == "thd":
torch.testing.assert_close(
full_output[seq, sim.t_total_lens[i] - 1, :],
incremental_output[cu_seqlens_q[i + 1] - 1, :],
atol=atol,
rtol=rtol,
)
sim.t += 1
sim.t_gen_lens = sim.t_gen_lens + 1
# last value in complete_times should be equal to sim.t
sim.serving_times = sim.arrival_times + sim.request_delays
sim.complete_times = sim.serving_times + sim.gen_lens
sim.print_summary(logger)
if backend == "FlashAttention" and not fa_utils.v3_is_installed:
config.max_seqlen_q = config_max_seqlen_q
config.max_seqlen_kv = config_max_seqlen_kv
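For reference, the paged-KV sizing in this test reduces to simple page arithmetic; a small Python illustration with hypothetical numbers in the spirit of the "infer_0" config above:
import math

def round_up(a: int, b: int) -> int:
    return b * math.ceil(a / b)

max_batch_size = 4
page_size = 256            # forced when only flash-attn v2 is available
max_seqlen_kv = round_up(64, page_size)   # -> 256

# Enough pages for every sequence in the batch at its maximum KV length.
total_num_pages = max_batch_size * max_seqlen_kv // page_size   # -> 4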
......@@ -2,9 +2,10 @@
#
# See LICENSE for license information.
from collections import OrderedDict
import math
import os
from typing import Dict, List, Optional
from typing import Dict, List, Tuple, Optional
import pytest
import copy
import random
......@@ -60,6 +61,8 @@ torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
torch._dynamo.config.recompile_limit = 16
class ModelConfig:
def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
......@@ -78,9 +81,9 @@ model_configs = {
model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
}
backends_inference = ["FlashAttention", "UnfusedAttention"]
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
......@@ -329,9 +332,9 @@ class TorchLayerNormLinear(nn.Module):
in_features: int,
out_features: int,
eps: float,
bias: bool = True,
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False,
bias: bool = True,
):
super().__init__()
if normalization == "LayerNorm":
......@@ -345,7 +348,7 @@ class TorchLayerNormLinear(nn.Module):
else:
raise RuntimeError("Unsupported normalization")
self.linear = nn.Linear(in_features, out_features)
self.linear = nn.Linear(in_features, out_features, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.layernorm(x))
......@@ -445,6 +448,7 @@ class TorchLayerNormMLP(nn.Module):
eps: float = 1e-5,
activation="gelu",
normalization: str = "LayerNorm",
bias: bool = True,
):
super().__init__()
if normalization == "LayerNorm":
......@@ -460,8 +464,8 @@ class TorchLayerNormMLP(nn.Module):
fc1_output_features = ffn_hidden_size
self.gelu = _supported_act[activation]
self.fc1 = nn.Linear(hidden_size, fc1_output_features)
self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)
self.fc1 = nn.Linear(hidden_size, fc1_output_features, bias=bias)
self.fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)
def forward(self, x):
t = self.gelu(self.fc1(self.ln(x)))
......@@ -672,8 +676,6 @@ def test_gpt_full_activation_recompute(
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for full recompute.")
config = model_configs[model]
......@@ -1039,6 +1041,8 @@ def _test_granular_accuracy(block, bs, dtype, config):
inp_hidden_states.retain_grad()
out = block(inp_hidden_states)
if isinstance(out, (List, Tuple)):
out = out[0]
loss = out.sum()
loss.backward()
......@@ -1117,32 +1121,53 @@ def test_dpa_accuracy(dtype, bs, model):
assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)
class TestReturnBiasModule(nn.Module):
def __init__(self, mod, **kwargs):
super().__init__()
self.te_module = mod(**kwargs)
self.return_bias = kwargs["return_bias"]
self.bias = kwargs["bias"]
def forward(self, x):
if self.return_bias:
out, bias = self.te_module(x)
if self.bias:
out = out + bias
return out
return self.te_module(x)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
def test_linear_accuracy(dtype, bs, model):
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_linear_accuracy(dtype, bs, model, return_bias, bias):
config = model_configs[model]
te_linear = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
te_linear = TestReturnBiasModule(
Linear,
in_features=config.hidden_size,
out_features=4 * config.hidden_size,
params_dtype=dtype,
return_bias=return_bias,
bias=bias,
device="cuda",
).eval()
)
torch_linear = torch.nn.Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
bias=bias,
device="cuda",
dtype=dtype,
).eval()
)
# Share params
with torch.no_grad():
torch_linear.weight = Parameter(te_linear.weight.clone())
torch_linear.bias = Parameter(te_linear.bias.clone())
torch_linear.weight = Parameter(te_linear.te_module.weight.clone())
if bias:
torch_linear.bias = Parameter(te_linear.te_module.bias.clone())
te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)
......@@ -1265,41 +1290,51 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_linear_accuracy(
dtype, bs, model, normalization, zero_centered_gamma, return_bias, bias
):
config = model_configs[model]
te_ln_linear = LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=True,
te_ln_linear = TestReturnBiasModule(
LayerNormLinear,
in_features=config.hidden_size,
out_features=4 * config.hidden_size,
eps=config.eps,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
return_bias=return_bias,
bias=bias,
device="cuda",
).eval()
)
torch_ln_linear = (
TorchLayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=True,
normalization=normalization,
zero_centered_gamma=zero_centered_gamma,
bias=bias,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_ln_linear.layernorm.weight = Parameter(te_ln_linear.layer_norm_weight.clone())
torch_ln_linear.layernorm.weight = Parameter(
te_ln_linear.te_module.layer_norm_weight.clone()
)
if normalization != "RMSNorm":
torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone())
torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone())
torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone())
torch_ln_linear.layernorm.bias = Parameter(
te_ln_linear.te_module.layer_norm_bias.clone()
)
torch_ln_linear.linear.weight = Parameter(te_ln_linear.te_module.weight.clone())
if bias:
torch_ln_linear.linear.bias = Parameter(te_ln_linear.te_module.bias.clone())
te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
......@@ -1339,17 +1374,22 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
config = model_configs[model]
te_ln_mlp = LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
te_ln_mlp = TestReturnBiasModule(
LayerNormMLP,
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
activation=activation,
normalization=normalization,
params_dtype=dtype,
return_bias=return_bias,
bias=bias,
device="cuda",
).eval()
)
torch_ln_mlp = (
TorchLayerNormMLP(
......@@ -1357,21 +1397,22 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
4 * config.hidden_size,
activation=activation,
normalization=normalization,
bias=bias,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.layer_norm_weight.clone())
torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.te_module.layer_norm_weight.clone())
if normalization != "RMSNorm":
torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone())
torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.fc1_weight.clone())
torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.fc1_bias.clone())
torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.fc2_weight.clone())
torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.fc2_bias.clone())
torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.te_module.layer_norm_bias.clone())
torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.te_module.fc1_weight.clone())
torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.te_module.fc2_weight.clone())
if bias:
torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.te_module.fc1_bias.clone())
torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.te_module.fc2_bias.clone())
te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)
......@@ -2040,14 +2081,25 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
@pytest.mark.parametrize("is_paged", [False, True])
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
reset_rng_states()
if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
pytest.skip("FusedAttention and FlashAttention do not support FP32")
if use_RoPE:
pytest.skip("KV cache does not support starting positions for RoPE")
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
elif backend == "UnfusedAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
config = model_configs_inference[model_key]
......@@ -2060,7 +2112,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
# Limits the max size of KV-cache
B_max = B
S_max = S + 2
S_max = S
if module == "TransformerLayer":
model = TransformerLayer(
......@@ -2090,7 +2142,17 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
.eval()
)
inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
inference_params = InferenceParams(
max_batch_size=B_max,
max_seqlen_kv=S_max,
num_heads_kv=H,
head_dim_k=head_size,
dtype=dtype,
is_paged=is_paged,
total_num_pages=int(B_max * S_max / 256),
page_size=256,
)
rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
input = torch.randn((S, B, D), dtype=dtype, device="cuda")
......@@ -2103,22 +2165,39 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
# Incrementally generate outputs using KV-cache
step_dict = OrderedDict(zip(list(range(B)), [1] * B))
for i in range(S):
inference_params.pre_step(step_dict)
if input_format == "sbhd":
incremental_input = input[i].view(1, B, D)
else:
incremental_input = input[:, i, :].view(B, 1, D)
seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv = cu_seqlens_q.clone()
mask_type = "padding"
kwargs = {}
if module == "TransformerLayer":
kwargs["self_attn_mask_type"] = mask_type
else:
kwargs["attn_mask_type"] = mask_type
line_output = model(
hidden_states=incremental_input,
inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None,
**kwargs,
max_seqlen_q=1,
max_seqlen_kv=S,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
inference_params.sequence_len_offset += 1
if input_format == "sbhd":
incremental_output[i] = line_output.view(B, D)
incremental_output[i, :, :] = line_output.view(B, D)
else:
incremental_output[:, i, :] = line_output.view(B, D)
......
......@@ -58,9 +58,13 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t)
}
}
void nvte_cudnn_handle_init() {
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
}
void nvte_cudnn_handle_init() { auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); }
namespace detail {
void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); }
} // namespace detail
#endif
} // namespace transformer_engine
......@@ -70,6 +74,6 @@ namespace cudnn_frontend {
// This is needed to define the symbol `cudnn_dlhandle`
// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING
// to enable dynamic loading.
void *cudnn_dlhandle = nullptr;
void* cudnn_dlhandle = nullptr;
} // namespace cudnn_frontend
......@@ -11,40 +11,26 @@
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#include <cudnn_graph.h>
#endif
#include <cstdint>
#include <mutex>
#include "transformer_engine/transformer_engine.h"
#include "util/handle_manager.h"
namespace transformer_engine {
#ifndef __HIP_PLATFORM_AMD__
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
namespace detail {
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
void CreateCuDNNHandle(cudnnHandle_t* handle);
class cudnnExecutionPlanManager {
public:
static cudnnExecutionPlanManager &Instance() {
static thread_local cudnnExecutionPlanManager instance;
return instance;
}
} // namespace detail
cudnnHandle_t GetCudnnHandle() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] { cudnnCreate(&handle_); });
return handle_;
}
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
~cudnnExecutionPlanManager() {}
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
private:
cudnnHandle_t handle_ = nullptr;
};
using cudnnExecutionPlanManager = detail::HandleManager<cudnnHandle_t, detail::CreateCuDNNHandle>;
#endif
} // namespace transformer_engine
#endif
#endif // TRANSFORMER_ENGINE_CUDNN_UTILS_H_
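The refactor replaces the hand-rolled cudnnExecutionPlanManager singleton with the generic HandleManager template from util/handle_manager.h: a lazily created, thread-local handle behind a Create callback. A Python analogy of that pattern (the real code is the C++ template; the handle here is a stand-in):
import threading

class HandleManager:
    """One lazily created handle per thread, cached after first use."""
    def __init__(self, create):
        self._create = create           # e.g. wraps cudnnCreate
        self._local = threading.local()

    def get(self):
        if not hasattr(self._local, "handle"):
            self._local.handle = self._create()
        return self._local.handle

# Usage sketch:
cudnn_handles = HandleManager(lambda: object())  # stand-in for a cuDNN handle
h = cudnn_handles.get()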