Unverified Commit 86f27e12 authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] Support non-deterministic algo for cuDNN FA (#1056)



* Support non-deterministic algo
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the helper function name
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move fixture to conftest.py
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 6717554f
......@@ -2,9 +2,12 @@
#
# See LICENSE for license information.
"""conftest for tests/jax"""
import os
import jax
import pytest
from transformer_engine.transformer_engine_jax import get_device_compute_capability
@pytest.fixture(autouse=True, scope="function")
def clear_live_arrays():
......@@ -14,3 +17,19 @@ def clear_live_arrays():
yield
for arr in jax.live_arrays():
arr.delete()
@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn():
    """
    Enable deterministic fused attn for hopper+ arch.

    When device 0 reports a compute capability >= 90, sets
    ``NVTE_FUSED_ATTN=1`` and ``NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`` while the
    tests of a module run. Fused attn kernels on pre-hopper arch are not
    deterministic, so the environment is left untouched there.

    On teardown both variables are put back to whatever they were before the
    fixture ran (and removed only if they were previously unset), so values
    exported by the user are not lost after the first test module.
    """
    managed_vars = ("NVTE_FUSED_ATTN", "NVTE_ALLOW_NONDETERMINISTIC_ALGO")
    # Snapshot before touching anything; None marks "was not set".
    previous = {name: os.environ.get(name) for name in managed_vars}

    if get_device_compute_capability(0) >= 90:
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

    yield

    # Restore rather than blindly delete, so pre-existing settings survive.
    for name, value in previous.items():
        if value is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = value
......@@ -15,7 +15,6 @@ import pytest
from utils import assert_allclose
from transformer_engine.transformer_engine_jax import get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
......@@ -43,19 +42,6 @@ ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn():
    """
    Turn on fused attn for hopper+ arch while a test module runs.

    Fused attn kernels on pre-hopper arch are not deterministic, so
    ``NVTE_FUSED_ATTN`` is only set when device 0 reports a compute
    capability of at least 90. The variable is removed again on teardown
    if it is present.
    """
    hopper_or_newer = get_device_compute_capability(0) >= 90
    if hopper_or_newer:
        os.environ["NVTE_FUSED_ATTN"] = "1"

    yield

    # Drop the variable if present; no-op when it was never set.
    os.environ.pop("NVTE_FUSED_ATTN", None)
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd:
assert key in test_fd, f"{key} not found in test dict {test_fd}"
......
......@@ -3,8 +3,9 @@
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
from dataclasses import dataclass
from functools import partial, reduce
from functools import partial, reduce, cache
import operator
import os
from typing import Optional, Tuple
import warnings
......@@ -84,6 +85,12 @@ class FusedAttnHelper:
self.head_dim,
)
@staticmethod
@cache
def is_non_deterministic_allowed():
    """Report whether non-deterministic kernels may be used.

    Reads the ``NVTE_ALLOW_NONDETERMINISTIC_ALGO`` environment variable,
    treating it as an integer flag (unset means ``"1"``): any non-zero value
    yields ``True``, zero yields ``False``.

    The result is memoized by ``functools.cache``, so the environment is only
    consulted on the first call; later changes to the variable are not seen
    unless the cache is cleared.
    """
    raw_flag = os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")
    return int(raw_flag) != 0
@staticmethod
def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
"""Parse qkv aval"""
......@@ -365,6 +372,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
......@@ -642,6 +650,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
deterministic = not FusedAttnHelper.is_non_deterministic_allowed()
input_batch = reduce(operator.mul, batch_shape)
wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
input_batch,
......@@ -659,6 +669,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
is_training,
deterministic,
max_segments_per_seq,
)
......@@ -764,6 +775,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
)
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
......
......@@ -147,6 +147,7 @@ struct CustomCallFusedAttnDescriptor {
DType dtype;
DType wkspace_dtype;
bool is_training;
bool deterministic;
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
......@@ -154,7 +155,8 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training);
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic);
// Transpose
......@@ -260,7 +262,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq);
bool deterministic, size_t max_segments_per_seq);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
......@@ -336,7 +336,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq) {
bool deterministic, size_t max_segments_per_seq) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
......@@ -392,13 +392,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, -1, -1, true, query_workspace_tensor.data(), nullptr);
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, deterministic,
query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
......@@ -408,7 +409,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, true, query_workspace_tensor.data(), nullptr);
-1, deterministic, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
......@@ -419,7 +420,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, true, query_workspace_tensor.data(), nullptr);
-1, deterministic, query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......@@ -467,6 +468,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto dtype = descriptor.dtype;
auto deterministic = descriptor.deterministic;
auto max_segments_per_seq = descriptor.max_segments_per_seq;
/* Input tensors */
......@@ -539,7 +541,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, true, workspace_tensor.data(), stream);
bias_type, mask_type, -1, -1, deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -566,7 +568,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, deterministic,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0];
......@@ -602,8 +604,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true,
workspace_tensor.data(), stream);
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1,
deterministic, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......
......@@ -68,11 +68,12 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training) {
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic) {
return PackOpaque(CustomCallFusedAttnDescriptor{
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic});
}
} // namespace jax
......
......@@ -359,6 +359,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel is not available on the system, a warning will be issued, and the module will
automatically fall back to the unfused backend.
.. note::
The DotProductAttention default setting enables non-deterministic kernels for reduced
workspace requirements and faster computation. Users can disable the non-deterministic
kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:
* :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
* :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).
Parameters
----------
head_dim: int
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment