"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "eac0d49b1ea7018959d8656929b6ed0615809965"
Unverified commit 86f27e12 authored by Reese Wang, committed by GitHub
Browse files

[JAX] Support non-deterministic algo for cuDNN FA (#1056)



* Support non-deterministic algo
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the helper function name
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move fixture to conftest.py
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 6717554f
...@@ -2,9 +2,12 @@ ...@@ -2,9 +2,12 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""conftest for tests/jax""" """conftest for tests/jax"""
import os
import jax import jax
import pytest import pytest
from transformer_engine.transformer_engine_jax import get_device_compute_capability
@pytest.fixture(autouse=True, scope="function") @pytest.fixture(autouse=True, scope="function")
def clear_live_arrays(): def clear_live_arrays():
...@@ -14,3 +17,19 @@ def clear_live_arrays(): ...@@ -14,3 +17,19 @@ def clear_live_arrays():
yield yield
for arr in jax.live_arrays(): for arr in jax.live_arrays():
arr.delete() arr.delete()
@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn():
    """
    Enable fused attn for hopper+ arch.

    Fused attn kernels on pre-hopper arch are not deterministic, so the
    fused backend is only switched on for compute capability >= 9.0 (sm90),
    while non-deterministic algorithms are disabled for every arch so test
    results stay reproducible.
    """
    if get_device_compute_capability(0) >= 90:
        os.environ["NVTE_FUSED_ATTN"] = "1"
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
    try:
        yield
    finally:
        # pop(key, None) is a no-op when the key is absent, so there is no
        # need for a membership check; the finally block guarantees the
        # environment is restored even if the fixture generator is closed
        # abnormally, preventing leakage into unrelated test modules.
        os.environ.pop("NVTE_FUSED_ATTN", None)
        os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
...@@ -15,7 +15,6 @@ import pytest ...@@ -15,7 +15,6 @@ import pytest
from utils import assert_allclose from utils import assert_allclose
from transformer_engine.transformer_engine_jax import get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_collections from transformer_engine.jax import fp8_autocast, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
...@@ -43,19 +42,6 @@ ENABLE_FP8 = [False, True] ...@@ -43,19 +42,6 @@ ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID] FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn():
    """
    Enable fused attn for hopper+ arch.
    Fused attn kernels on pre-hopper arch are not deterministic.
    """
    # Compute capability 90 == Hopper (sm90); only then is the fused
    # attention backend switched on via the NVTE_FUSED_ATTN env var.
    if get_device_compute_capability(0) >= 90:
        os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    # Teardown: remove the env var (if set above) so it does not leak
    # into other test modules.
    if "NVTE_FUSED_ATTN" in os.environ:
        del os.environ["NVTE_FUSED_ATTN"]
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd: for key in ref_fd:
assert key in test_fd, f"{key} not found in test dict {test_fd}" assert key in test_fd, f"{key} not found in test dict {test_fd}"
......
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE custom ops for attention""" """JAX/TE custom ops for attention"""
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial, reduce from functools import partial, reduce, cache
import operator import operator
import os
from typing import Optional, Tuple from typing import Optional, Tuple
import warnings import warnings
...@@ -84,6 +85,12 @@ class FusedAttnHelper: ...@@ -84,6 +85,12 @@ class FusedAttnHelper:
self.head_dim, self.head_dim,
) )
@staticmethod
@cache
def is_non_deterministic_allowed():
"""Check if non-deterministic kernels are allowed"""
return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
@staticmethod @staticmethod
def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
"""Parse qkv aval""" """Parse qkv aval"""
...@@ -365,6 +372,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -365,6 +372,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training, is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
) )
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
...@@ -642,6 +650,8 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -642,6 +650,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape) bias_batch = reduce(operator.mul, bias_batch_shape)
deterministic = not FusedAttnHelper.is_non_deterministic_allowed()
input_batch = reduce(operator.mul, batch_shape) input_batch = reduce(operator.mul, batch_shape)
wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
input_batch, input_batch,
...@@ -659,6 +669,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -659,6 +669,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
qkv_layout, qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype),
is_training, is_training,
deterministic,
max_segments_per_seq, max_segments_per_seq,
) )
...@@ -764,6 +775,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -764,6 +775,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training, is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
) )
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
......
...@@ -147,6 +147,7 @@ struct CustomCallFusedAttnDescriptor { ...@@ -147,6 +147,7 @@ struct CustomCallFusedAttnDescriptor {
DType dtype; DType dtype;
DType wkspace_dtype; DType wkspace_dtype;
bool is_training; bool is_training;
bool deterministic;
}; };
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
...@@ -154,7 +155,8 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( ...@@ -154,7 +155,8 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training); NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic);
// Transpose // Transpose
...@@ -260,7 +262,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -260,7 +262,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq); bool deterministic, size_t max_segments_per_seq);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
...@@ -336,7 +336,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -336,7 +336,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq) { bool deterministic, size_t max_segments_per_seq) {
// For qkv_packed // For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
...@@ -392,13 +392,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -392,13 +392,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto dummy_ragged_offset_tensor = auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
qkv_layout, bias_type, mask_type, -1, -1, true, query_workspace_tensor.data(), nullptr); bias_type, mask_type, -1, -1, deterministic,
query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_bwd_kvpacked( nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
...@@ -408,7 +409,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -408,7 +409,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, true, query_workspace_tensor.data(), nullptr); -1, deterministic, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(), doutput_tensor.data(),
...@@ -419,7 +420,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -419,7 +420,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, true, query_workspace_tensor.data(), nullptr); -1, deterministic, query_workspace_tensor.data(), nullptr);
} else { } else {
NVTE_ERROR("Unsupported qkv_layout."); NVTE_ERROR("Unsupported qkv_layout.");
} }
...@@ -467,6 +468,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -467,6 +468,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto bias_type = descriptor.bias_type; auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type; auto mask_type = descriptor.mask_type;
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto deterministic = descriptor.deterministic;
auto max_segments_per_seq = descriptor.max_segments_per_seq; auto max_segments_per_seq = descriptor.max_segments_per_seq;
/* Input tensors */ /* Input tensors */
...@@ -539,7 +541,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -539,7 +541,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, true, workspace_tensor.data(), stream); bias_type, mask_type, -1, -1, deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -566,7 +568,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -566,7 +568,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true, dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, deterministic,
workspace_tensor.data(), stream); workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0]; auto q = buffers[0];
...@@ -602,8 +604,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -602,8 +604,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
dbias_tensor.data(), q_cu_seqlens_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true, dropout_probability, qkv_layout, bias_type, mask_type, -1, -1,
workspace_tensor.data(), stream); deterministic, workspace_tensor.data(), stream);
} else { } else {
NVTE_ERROR("Unsupported qkv_layout."); NVTE_ERROR("Unsupported qkv_layout.");
} }
......
...@@ -68,11 +68,12 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( ...@@ -68,11 +68,12 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training) { NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic) {
return PackOpaque(CustomCallFusedAttnDescriptor{ return PackOpaque(CustomCallFusedAttnDescriptor{
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training}); mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic});
} }
} // namespace jax } // namespace jax
......
...@@ -359,6 +359,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -359,6 +359,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel is not available on the system, a warning will be issued, and the module will kernel is not available on the system, a warning will be issued, and the module will
automatically fall back to the unfused backend. automatically fall back to the unfused backend.
.. note::
The DotProductAttention default setting enables non-deterministic kernels for reduced
workspace requirements and faster computation. Users can disable the non-deterministic
kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:
* :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
* :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).
Parameters Parameters
---------- ----------
head_dim: int head_dim: int
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment