"vscode:/vscode.git/clone" did not exist on "f73d623791994af3282bf7aa5d161288272557dd"
Unverified commit 73c9f421, authored by zlsh80826 and committed by GitHub

Add FP16/BF16 fused_attention support with max_seqlen=512 (#175)



* Add fused attention unit tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use NVTE_* enums
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use NVTE_Mask_Type and remove FMHADescriptor
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move common functions to utils
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Change namespace to fused_attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move fused_attn_max_512_fwd_qkvpacked under the general APIs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add fused_attn_max_512_bwd_qkvpacked
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move fused_attn_max_512_bwd_qkvpacked under the general APIs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove redundant blank line
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix a potential bug for cu_seqlen converter
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Reformat fused_attn_max_512
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the unfused attention warning message
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename to fused_attn_max_512
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove the deprecated header
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix flax import
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename to fused attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add attention related mask
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add attn_mask_type and attn_bias_type
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refactor jax primitive API
* Merge q_cu_seqlen and kv_cu_seqlen
* Remove is_causal_masking
* Replace seed with rng_state
* Add is_training argument
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove dsoftmax from the customcall
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add None guard for bias and dropout_rng
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add version guard
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add is_fused_attn_kernel_available() to correctly dispatch the attention impl
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix the merge conflict
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Adjust the code style
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add the missing blank lines
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Change the order of FADescriptor members
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance the readability of fused_attn_max_512.cu
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Generalize the input dimension unpacking
Signed-off-by: Reese Wang <rewang@nvidia.com>

* 16 bits fused attention requires 8.9.1
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update fused attention support matrix
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Handle None type when sharding
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Change to the padding ratio
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Performance optimization for non-bias cases
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Revert the cudnn-frontend PRIVATE keyword which was used for debugging
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Revert "Update fused attention support matrix"

This reverts commit 4effe67d0f08f733919a329ce5ab421958740f4a.
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Treat b * s as total_seqs to align ragged cases
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add FP16/BF16 max_seqlen <= 512 fused attention to the support matrix
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine test_fused_attn.py

* Replace reference code with flax.linen
* Remove unnecessary comments
* Use AttnMaskType
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Unify the cuDNN compile version
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add dropout to the support matrix
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Slightly adjust the headers
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Typo fix: remove redundant either
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Consolidating fused attention requirements
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Replace cudnn_frontend::throw_if with NVTE_CHECK for the better error line report
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename to fused_attn_fp16_bf16_max_seqlen_512 for the better readability
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove CUDNN_FRONTEND_UNUSED
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add more annotations to the custom calls
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent c6a4a4e0
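One of the commits above introduces is_fused_attn_kernel_available() so callers can route between the fused kernel and the unfused fallback. A minimal sketch of that dispatch idea, assuming hypothetical helper names and signatures (not the exact API added by this PR):

```python
import jax.numpy as jnp

# Hypothetical stand-ins for the real Transformer Engine helpers; names and
# signatures here are illustrative only.
def is_fused_attn_kernel_available(dtype, max_seqlen):
    # The fused path added by this PR targets FP16/BF16 with max_seqlen <= 512.
    return dtype in (jnp.float16, jnp.bfloat16) and max_seqlen <= 512

def attention(qkv, mask, *, dropout_rate, fused_impl, unfused_impl):
    # qkv: [batch, seqlen, 3, heads, head_dim]; pick the implementation per dtype/seqlen.
    seqlen = qkv.shape[1]
    if is_fused_attn_kernel_available(qkv.dtype, seqlen):
        return fused_impl(qkv, mask, dropout_rate=dropout_rate)
    return unfused_impl(qkv, mask, dropout_rate=dropout_rate)
```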
@@ -15,7 +15,7 @@ Prerequisites
2. `CUDA 11.8 <https://developer.nvidia.com/cuda-downloads>`__
3. |driver link|_ supporting CUDA 11.8 or later.
4. `cuDNN 8.1 <https://developer.nvidia.com/cudnn>`__ or later.
-5. For FP8 fused attention, `CUDA 12.1 <https://developer.nvidia.com/cuda-downloads>`__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9 <https://developer.nvidia.com/cudnn>`__ or later.
5. For FP8/FP16/BF16 fused attention, `CUDA 12.1 <https://developer.nvidia.com/cuda-downloads>`__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9.1 <https://developer.nvidia.com/cudnn>`__ or later.

Transformer Engine in NGC Containers
......
@@ -54,7 +54,7 @@ def compare_frozen_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
                                err_msg=f"{key=} is not close")

-DATA_SHAPE = [(128, 32, 512), (512, 32, 512)]    # (seqlen, batch, emb_dim)
DATA_SHAPE = [(32, 128, 1024), (32, 512, 1024)]    # (batch, seqlen, emb_dim)
DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@@ -68,6 +68,7 @@ _KEY_OF_FUSE_MLP_WI = "fuse_mlp_wi"
_KEY_OF_LAYERNORM_TYPE = 'layernorm_type'
_KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma'
_KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"

BASE_ATTRS = {_KEY_OF_TRANSPOSE_BS: True}
@@ -99,6 +100,13 @@ ATTRS = [{
    _KEY_OF_DROPOUT_RATE: 0.0,
    _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
    _KEY_OF_FUSE_MLP_WI: True
}, {
    _KEY_OF_TRANSPOSE_BS: False,
    _KEY_OF_SCALE_ATTN_LOGITS: True,
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_DROPOUT_RATE: 0.0,
    _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
    _KEY_OF_FUSE_MLP_WI: True
}]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
@@ -129,12 +137,15 @@ class TestEncoderLayer:
        return ref, flax.core.frozen_dict.FrozenDict(unfreeze_target)

    def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        sequence_dim = 0 if transpose_batch_sequence else 1

        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)
-        batch = data_shape[1] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[0]
-        seqlen = data_shape[0] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[1]
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        ref_masks = (1 - padded_mask,)
        test_masks = (None, padded_mask)    # The second arg of Transformer is encoded tokens.
@@ -149,7 +160,6 @@ class TestEncoderLayer:
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
-        sequence_dim = 0
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.ENCODER,
@@ -171,12 +181,15 @@ class TestEncoderLayer:
        del data_rng, init_rng, apply_rng

    def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        sequence_dim = 0 if transpose_batch_sequence else 1

        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)
-        batch = data_shape[1] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[0]
-        seqlen = data_shape[0] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[1]
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        ref_masks = (1 - padded_mask,)
        test_masks = (None, padded_mask)    # The second arg of Transformer is encoded tokens.
@@ -191,7 +204,6 @@ class TestEncoderLayer:
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
-        sequence_dim = 0
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.ENCODER,
@@ -335,11 +347,13 @@ class TestDecoderLayer:
        return ref, flax.core.frozen_dict.FrozenDict(unfreeze_target)

    def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
-        data_rng, init_rng, apply_rng = generate_test_rngs()
-        batch = data_shape[1] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[0]
-        seqlen = data_shape[0] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[1]
        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        sequence_dim = 0 if transpose_batch_sequence else 1

        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape,
                                    dtype), jax.random.normal(data_rng, data_shape, dtype))
@@ -358,7 +372,6 @@ class TestDecoderLayer:
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
-        sequence_dim = 0
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.DECODER,
@@ -379,11 +392,13 @@ class TestDecoderLayer:
        del data_rng, init_rng, apply_rng

    def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
-        data_rng, init_rng, apply_rng = generate_test_rngs()
-        batch = data_shape[1] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[0]
-        seqlen = data_shape[0] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[1]
        transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
        sequence_dim = 0 if transpose_batch_sequence else 1

        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape,
                                    dtype), jax.random.normal(data_rng, data_shape, dtype))
@@ -402,7 +417,6 @@ class TestDecoderLayer:
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
-        sequence_dim = 0
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.DECODER,
......
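The test hunks above derive the tensor layout from the transpose_batch_sequence attribute instead of hard-coding sequence_dim = 0. A small sketch of that rule, mirroring the new test logic (illustrative only):

```python
def layout_for(data_shape, transpose_batch_sequence):
    """data_shape is (batch, seqlen, emb_dim); optionally swap to (seqlen, batch, emb_dim)."""
    batch, seqlen = data_shape[:2]
    if transpose_batch_sequence:
        data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
    sequence_dim = 0 if transpose_batch_sequence else 1
    return data_shape, batch, seqlen, sequence_dim

# (32, 128, 1024) stays batch-major with sequence_dim == 1, while
# transpose_batch_sequence=True yields (128, 32, 1024) and sequence_dim == 0.
assert layout_for((32, 128, 1024), True) == ((128, 32, 1024), 32, 128, 0)
assert layout_for((32, 128, 1024), False) == ((32, 128, 1024), 32, 128, 1)
```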
@@ -12,6 +12,7 @@ list(APPEND transformer_engine_SOURCES
     transpose/transpose_fusion.cu
     transpose/multi_cast_transpose.cu
     activation/gelu.cu
     fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu
     fused_attn/fused_attn_fp8.cu
     fused_attn/fused_attn.cpp
     fused_attn/utils.cu
......
@@ -7,6 +7,7 @@
#include "transformer_engine/fused_attn.h"
#include "../common.h"
#include "utils.h"
#include "fused_attn_fp16_bf16_max_seqlen_512.h"
#include "fused_attn_fp8.h"

// NVTE fused attention FWD FP8 with packed QKV
@@ -26,6 +27,7 @@ void nvte_fused_attn_fwd_qkvpacked(
            cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
  using namespace transformer_engine;

  const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
  const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
  const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
@@ -35,15 +37,17 @@ void nvte_fused_attn_fwd_qkvpacked(
  Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);

  // QKV shape is [total_seqs, 3, h, d]
  auto ndim = input_QKV->data.shape.size();
  size_t b = input_cu_seqlens->data.shape[0] - 1;
-  size_t h = input_QKV->data.shape[2];
-  size_t d = input_QKV->data.shape[3];
  size_t h = input_QKV->data.shape[ndim - 2];
  size_t d = input_QKV->data.shape[ndim - 1];

  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();

  const DType QKV_type = input_QKV->data.dtype;

  if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
              && (max_seqlen <= 512)) {
#if (CUDNN_VERSION >= 8900)
-    auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
    // FP8 API doesn't use input_Bias, bias_type or attn_mask_type
    fused_attn_fwd_fp8_qkvpacked(
            b, max_seqlen, h, d,
@@ -58,7 +62,31 @@ void nvte_fused_attn_fwd_qkvpacked(
#endif
  } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
              && (max_seqlen <= 512)) {
-    NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_qkvpacked(
b,
max_seqlen,
h,
d,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_QKV,
input_Bias,
output_O,
Aux_Output_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace,
stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
  } else if (max_seqlen > 512) {
    NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
  } else {
@@ -84,6 +112,7 @@ void nvte_fused_attn_bwd_qkvpacked(
            cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
  using namespace transformer_engine;

  const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
  const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
  const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
@@ -95,9 +124,12 @@ void nvte_fused_attn_bwd_qkvpacked(
  Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);

  // QKV shape is [total_seqs, 3, h, d]
  auto ndim = input_QKV->data.shape.size();
  size_t b = input_cu_seqlens->data.shape[0] - 1;
-  size_t h = input_QKV->data.shape[2];
-  size_t d = input_QKV->data.shape[3];
  size_t h = input_QKV->data.shape[ndim - 2];
  size_t d = input_QKV->data.shape[ndim - 1];

  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();

  const DType QKV_type = input_QKV->data.dtype;

  if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
@@ -107,7 +139,7 @@ void nvte_fused_attn_bwd_qkvpacked(
    const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
    const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
    const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
-    auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
    // FP8 API doesn't use input_dBias, bias_type or attn_mask_type
    fused_attn_bwd_fp8_qkvpacked(
            b, max_seqlen, h, d,
@@ -124,7 +156,30 @@ void nvte_fused_attn_bwd_qkvpacked(
#endif
  } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
              && (max_seqlen <= 512)) {
-    NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_bwd_qkvpacked(
b,
max_seqlen,
h,
d,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_QKV,
input_dO,
Aux_CTX_Tensors,
output_dQKV,
output_dBias,
input_cu_seqlens,
wkspace,
stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
  } else if (max_seqlen > 512) {
    NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
  } else {
@@ -161,9 +216,13 @@ void nvte_fused_attn_fwd_kvpacked(
  Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);

  // Q shape is [total_seqs, h, d]
  // KV shape is [total_seqs, h, d]
  auto ndim = input_Q->data.shape.size();
  size_t b = input_cu_seqlens_q->data.shape[0] - 1;
-  size_t h = input_Q->data.shape[1];
-  size_t d = input_Q->data.shape[2];
  size_t h = input_Q->data.shape[ndim - 2];
  size_t d = input_Q->data.shape[ndim - 1];

  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();

  const DType QKV_type = input_Q->data.dtype;

  if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
@@ -171,7 +230,34 @@ void nvte_fused_attn_fwd_kvpacked(
    NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
  } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
              && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
-    NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_kvpacked(
b,
max_seqlen_q,
max_seqlen_kv,
h,
d,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_Q,
input_KV,
input_Bias,
output_O,
Aux_Output_Tensors,
input_cu_seqlens_q,
input_cu_seqlens_kv,
input_rng_state,
wkspace,
stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
  } else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
    NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
  } else {
@@ -214,16 +300,48 @@ void nvte_fused_attn_bwd_kvpacked(
  Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);

  // Q shape is [total_seqs, h, d]
  // KV shape is [total_seqs, h, d]
  auto ndim = input_Q->data.shape.size();
  size_t b = input_cu_seqlens_q->data.shape[0] - 1;
-  size_t h = input_Q->data.shape[1];
-  size_t d = input_Q->data.shape[2];
  size_t h = input_Q->data.shape[ndim - 2];
  size_t d = input_Q->data.shape[ndim - 1];

  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();

  const DType QKV_type = input_Q->data.dtype;

  if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
              && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
    NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
  } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
              && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
-    NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n");
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_bwd_kvpacked(
b,
max_seqlen_q,
max_seqlen_kv,
h,
d,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_Q,
input_KV,
input_dO,
Aux_CTX_Tensors,
output_dQ,
output_dKV,
output_dBias,
input_cu_seqlens_q,
input_cu_seqlens_kv,
wkspace,
stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
  } else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
    NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
  } else {
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file fused_attn_max_512.h
* \brief Functions for fused attention with seqlen <= 512
*/
#ifndef TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_
#define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_
#include "transformer_engine/fused_attn.h"
#include <cudnn.h>
#include "common/common.h"
namespace transformer_engine {
#if (CUDNN_VERSION >= 8901)
void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_size, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_Output_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t num_head, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8901
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_
@@ -5,6 +5,7 @@
 ************************************************************************/

#include "transformer_engine/fused_attn.h"
#include "../common.h"
#include "utils.h"
#include "fused_attn_fp8.h"
@@ -70,84 +71,6 @@ std::unordered_map<std::string, int> tensor_name_to_uid = {
    {"VIRTUAL", 80}
};
bool allowAllConfig(cudnnBackendDescriptor_t engine_config) {
(void)engine_config;
return false;
}
static cudnn_frontend::Tensor tensor_create(
cudnnDataType_t type, int64_t id,
int64_t const * dim, int64_t const * stride,
bool is_virtual, bool is_value) {
int nbDims = 4;
auto tensor_created = cudnn_frontend::TensorBuilder()
.setDim(nbDims, dim)
.setStride(nbDims, stride)
.setId(id)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(type)
.setVirtual(is_virtual)
.setByValue(is_value)
.build();
return tensor_created;
}
static cudnn_frontend::Tensor tensor_create_with_offset(
cudnnDataType_t type, int64_t id,
int64_t const * dim, int64_t const * stride,
bool is_virtual, bool is_value,
std::shared_ptr<cudnn_frontend::Tensor> raggedOffset) {
int nbDims = 4;
auto tensor_created = cudnn_frontend::TensorBuilder()
.setDim(nbDims, dim)
.setStride(nbDims, stride)
.setId(id)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(type)
.setVirtual(is_virtual)
.setByValue(is_value)
.setRaggedOffset(raggedOffset)
.build();
return tensor_created;
}
static cudnn_frontend::PointWiseDesc pw_desc_create(
cudnnDataType_t type, cudnnPointwiseMode_t mode) {
auto pw_desc_created = cudnn_frontend::PointWiseDescBuilder()
.setMode(mode)
.setComputeType(type)
.build();
return pw_desc_created;
}
static cudnn_frontend::Operation unary_pw_op_create(
cudnn_frontend::Tensor const &xDesc,
cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc) {
auto pw_op_created = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(xDesc)
.setyDesc(yDesc)
.setpwDesc(pwDesc)
.build();
return pw_op_created;
}
static cudnn_frontend::Operation binary_pw_op_create(
cudnn_frontend::Tensor const &xDesc,
cudnn_frontend::Tensor const &bDesc,
cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc) {
auto pw_op_created = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(xDesc)
.setbDesc(bDesc)
.setyDesc(yDesc)
.setpwDesc(pwDesc)
.build();
return pw_op_created;
}
static cudnn_frontend::Tensor createAmax(
    const std::string& amax_tensor_name,
    const cudnn_frontend::Tensor& prevBlockOutputTensor,
@@ -1089,7 +1012,8 @@ void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
    FADescriptor descriptor{
                b, h, s_q, s_kv, d,
-               attnScale, isTraining, dropoutProbability, layout, tensorType};
                attnScale, isTraining, dropoutProbability, layout,
                NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType};

    using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
    static CacheType fa_fprop_cache;
@@ -1404,7 +1328,8 @@ void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
    FADescriptor descriptor{
                b, h, s_q, s_kv, d,
-               attnScale, false, dropoutProbability, layout, tensorType};
                attnScale, false, dropoutProbability, layout,
                NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_PADDING_MASK, tensorType};

    using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
    static CacheType fa_bprop_cache;
......
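The FADescriptor changes above add bias_type and mask_type to the key used when caching cuDNN execution plans, so plans built for different bias/mask configurations no longer collide. A rough sketch of that caching pattern in Python terms (illustrative only; the real code keys a std::map via FADescriptor::operator<):

```python
from collections import namedtuple

# Illustrative stand-in for FADescriptor; the field order mirrors the struct in the diff.
FADescriptor = namedtuple(
    "FADescriptor",
    "b h s_q s_kv d attn_scale is_training dropout layout bias_type mask_type tensor_type")

_plan_cache = {}

def get_or_build_plan(desc, build_plan):
    """Reuse a previously built execution plan for an identical descriptor."""
    if desc not in _plan_cache:
        _plan_cache[desc] = build_plan(desc)
    return _plan_cache[desc]
```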
@@ -51,13 +51,13 @@ void generateMatrixStrides(
            strideA[head_dim_idx] = d;
            strideA[batch_dim_idx] = s_kv * 3 * h * d;
        } else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) {
-            strideA[seqlen_transpose_dim_idx] = 2 * h * d;
-            strideA[hidden_transpose_dim_idx] = 1;
            strideA[seqlen_dim_idx] = 2 * h * d;
            strideA[hidden_dim_idx] = 1;
            strideA[head_dim_idx] = d;
            strideA[batch_dim_idx] = s_kv * 2 * h * d;
        } else {
-            strideA[seqlen_transpose_dim_idx] = h * d;
-            strideA[hidden_transpose_dim_idx] = 1;
            strideA[seqlen_dim_idx] = h * d;
            strideA[hidden_dim_idx] = 1;
            strideA[head_dim_idx] = d;
            strideA[batch_dim_idx] = s_kv * h * d;
        }
@@ -131,6 +131,99 @@ void generateMatrixStrides(
    }
}
bool allowAllConfig(cudnnBackendDescriptor_t engine_config) {
(void)engine_config;
return false;
}
cudnn_frontend::Tensor tensor_create(
cudnnDataType_t type, int64_t id,
int64_t const * dim, int64_t const * stride,
bool is_virtual, bool is_value) {
int nbDims = 4;
auto tensor_created = cudnn_frontend::TensorBuilder()
.setDim(nbDims, dim)
.setStride(nbDims, stride)
.setId(id)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(type)
.setVirtual(is_virtual)
.setByValue(is_value)
.build();
return tensor_created;
}
cudnn_frontend::Tensor tensor_create_with_offset(
cudnnDataType_t type, int64_t id,
int64_t const * dim, int64_t const * stride,
bool is_virtual, bool is_value,
std::shared_ptr<cudnn_frontend::Tensor> raggedOffset) {
int nbDims = 4;
auto tensor_created = cudnn_frontend::TensorBuilder()
.setDim(nbDims, dim)
.setStride(nbDims, stride)
.setId(id)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(type)
.setVirtual(is_virtual)
.setByValue(is_value)
.setRaggedOffset(raggedOffset)
.build();
return tensor_created;
}
cudnn_frontend::PointWiseDesc pw_desc_create(
cudnnDataType_t type, cudnnPointwiseMode_t mode) {
auto pw_desc_created = cudnn_frontend::PointWiseDescBuilder()
.setMode(mode)
.setComputeType(type)
.build();
return pw_desc_created;
}
cudnn_frontend::Operation unary_pw_op_create(
cudnn_frontend::Tensor const &xDesc,
cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc) {
auto pw_op_created = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(xDesc)
.setyDesc(yDesc)
.setpwDesc(pwDesc)
.build();
return pw_op_created;
}
cudnn_frontend::Operation binary_pw_op_create(
cudnn_frontend::Tensor const &xDesc,
cudnn_frontend::Tensor const &bDesc,
cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc) {
auto pw_op_created = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(xDesc)
.setbDesc(bDesc)
.setyDesc(yDesc)
.setpwDesc(pwDesc)
.build();
return pw_op_created;
}
cudnn_frontend::Operation ternary_pw_op_create(
cudnn_frontend::Tensor const &xDesc, cudnn_frontend::Tensor const &bDesc,
cudnn_frontend::Tensor const &tDesc, cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc) {
auto pw_op_created = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(xDesc)
.setbDesc(bDesc)
.settDesc(tDesc)
.setyDesc(yDesc)
.setpwDesc(pwDesc)
.build();
return pw_op_created;
}
// convert cu_seqlens_q to qkv/o_ragged_offset and actual_seqlens_q
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
                                      int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
@@ -144,6 +237,19 @@ __global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
        o_ragged_offset[tid] = cu_seqlens_q[tid] * h * d;
    }
}
// convert cu_seqlens to actual_seqlens
__global__ void cu_seqlens_to_actual_seqlens(size_t b,
int32_t const * const q_cu_seqlens,
int32_t const * const kv_cu_seqlens,
int32_t *q_seqlens, int32_t *kv_seqlens) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < b) {
q_seqlens[tid] = q_cu_seqlens[tid + 1] - q_cu_seqlens[tid];
kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid];
}
}
}  // namespace fused_attn

// get cuDNN data type
......
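The new cu_seqlens_to_actual_seqlens kernel above turns cumulative sequence lengths into per-sequence lengths (seqlens[i] = cu_seqlens[i + 1] - cu_seqlens[i]). The same conversion as a one-line host-side sketch, for reference only:

```python
import jax.numpy as jnp

def cu_seqlens_to_actual_seqlens(cu_seqlens):
    # cu_seqlens has batch + 1 entries, e.g. [0, 5, 12, 20] -> per-sequence lengths [5, 7, 8]
    return cu_seqlens[1:] - cu_seqlens[:-1]

assert (cu_seqlens_to_actual_seqlens(jnp.array([0, 5, 12, 20])) == jnp.array([5, 7, 8])).all()
```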
@@ -7,9 +7,15 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_

#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"

#include <cudnn.h>
#include <cudnn_frontend.h>

#include <cstdint>
#include <mutex>

namespace transformer_engine {
namespace fused_attn {
@@ -31,6 +37,36 @@ void generateMatrixStrides(
                           int64_t d, int64_t* strideA,
                           NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix);
bool allowAllConfig(cudnnBackendDescriptor_t engine_config);
cudnn_frontend::Tensor tensor_create(cudnnDataType_t type, int64_t id,
int64_t const *dim,
int64_t const *stride,
bool is_virtual, bool is_value);
cudnn_frontend::Tensor tensor_create_with_offset(
cudnnDataType_t type, int64_t id,
int64_t const * dim, int64_t const * stride,
bool is_virtual, bool is_value,
std::shared_ptr<cudnn_frontend::Tensor> raggedOffset);
cudnn_frontend::PointWiseDesc pw_desc_create(cudnnDataType_t type,
cudnnPointwiseMode_t mode);
cudnn_frontend::Operation unary_pw_op_create(
cudnn_frontend::Tensor const &xDesc, cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc);
cudnn_frontend::Operation binary_pw_op_create(
cudnn_frontend::Tensor const &xDesc, cudnn_frontend::Tensor const &bDesc,
cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc);
cudnn_frontend::Operation ternary_pw_op_create(
cudnn_frontend::Tensor const &xDesc, cudnn_frontend::Tensor const &bDesc,
cudnn_frontend::Tensor const &tDesc, cudnn_frontend::Tensor const &yDesc,
cudnn_frontend::PointWiseDesc const &pwDesc);
struct FADescriptor {
    std::int64_t b;
    std::int64_t h;
@@ -41,15 +77,19 @@ struct FADescriptor {
    bool isTraining;
    float dropoutProbability;
    NVTE_QKV_Layout layout;
    NVTE_Bias_Type bias_type;
    NVTE_Mask_Type mask_type;
    cudnnDataType_t tensor_type;

    bool operator<(const FADescriptor &rhs) const {
        return std::tie(b, h, s_q, s_kv, d,
                        attnScale, isTraining, dropoutProbability,
                        layout, mask_type, bias_type, tensor_type)
               < std::tie(rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d,
                          rhs.attnScale, rhs.isTraining,
                          rhs.dropoutProbability, rhs.layout,
                          rhs.mask_type, rhs.bias_type, rhs.tensor_type);
    }
};
@@ -57,6 +97,11 @@ __global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
                                      int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
                                      int32_t *qkv_ragged_offset, int32_t *o_ragged_offset);
__global__ void cu_seqlens_to_actual_seqlens(size_t b,
int32_t const * const q_cu_seqlens,
int32_t const * const kv_cu_seqlens,
int32_t *q_seqlens, int32_t *kv_seqlens);
}  // namespace fused_attn

cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
......
@@ -78,8 +78,9 @@ enum NVTE_Mask_Type {
 * - O = D * V.T
 *
 * Support Matrix:
-* | precision | qkv layout      | bias                    | mask           | sequence length | head_dim |
-* | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | <= 512          | 64       |
 * | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
 * | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
 * | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
 *
 *
 * \param[in]     QKV       The QKV tensor in packed format,
@@ -119,8 +120,9 @@ void nvte_fused_attn_fwd_qkvpacked(

/*! \brief Compute the backward of the dot product attention with packed QKV input.
 *
 * Support Matrix:
-* | precision | qkv layout      | bias                    | mask           | sequence length | head_dim |
-* | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | <= 512          | 64       |
 * | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
 * | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
 * | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
 *
 *
 * \param[in]     QKV       The QKV tensor in packed format,
@@ -168,6 +170,11 @@ void nvte_fused_attn_bwd_qkvpacked(
 * - D = Dropout(S)
 * - O = D * V.T
 *
* Support Matrix:
* | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
* | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No | <= 512 | 64 |
*
*
 * \param[in]     Q         The Q tensor, [total_seqs_q, num_heads, head_dim].
 * \param[in]     KV        The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
 * \param[in]     Bias      The Bias tensor.
@@ -208,6 +215,11 @@ void nvte_fused_attn_fwd_kvpacked(
            cudaStream_t stream);

/*! \brief Compute the backward of the dot product attention with packed KV input.
*
* Support Matrix:
* | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
* | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No | <= 512 | 64 |
*
 *
 * \param[in]     Q         The Q tensor, [total_seqs_q, num_heads, head_dim].
 * \param[in]     KV        The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
......
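The doc comments above spell out the computation (a scaled Q*K^T product, optional bias, masked softmax, optional dropout, then the product with V) together with its support matrix. A minimal unfused JAX reference of that math for batched single-layout inputs, illustrative only and not the library API (dropout is omitted, matching the "No" entry for the FP16/BF16 path):

```python
import jax
import jax.numpy as jnp

def unfused_attention_reference(q, k, v, mask=None, bias=None, scale=1.0):
    """q, k, v: [batch, seqlen, head_dim]; mask is True/1 where attention is allowed."""
    s = scale * jnp.einsum("bqd,bkd->bqk", q, k)   # scaled dot-product logits
    if bias is not None:                           # POST_SCALE_BIAS
        s = s + bias
    if mask is not None:                           # PADDING or CAUSAL mask
        s = jnp.where(mask, s, jnp.finfo(s.dtype).min)
    p = jax.nn.softmax(s, axis=-1)                 # (Dropout would be applied here.)
    return jnp.einsum("bqk,bkd->bqd", p, v)
```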
@@ -19,6 +19,8 @@ from jax.interpreters.mlir import ir, dtype_to_ir_type

import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
for _name, _value in transformer_engine_jax.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")
@@ -1973,3 +1975,453 @@ def scaled_upper_triang_masked_softmax_bwd(grad_outputs: jnp.ndarray, softmax_ou
    return _scaled_upper_triang_masked_softmax_bwd_p.bind(grad_outputs,
                                                          softmax_outputs,
                                                          scale_factor=scale_factor)
class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
"""
Self Fused Attention Max Seqlen 512 Forward Primitive
"""
name = "te_self_fused_attn_max_512_forward"
multiple_results = True
@staticmethod
def abstract(
qkv,
bias,
cu_seqlen, # pylint: disable=unused-argument
rng_state, # pylint: disable=unused-argument
*,
attn_bias_type, # pylint: disable=unused-argument
attn_mask_type, # pylint: disable=unused-argument
scaling_factor, # pylint: disable=unused-argument
dropout_probability, # pylint: disable=unused-argument
is_training # pylint: disable=unused-argument
):
"""
Self fused attention max seqlen 512 fwd abstract
"""
qkv_dtype = dtypes.canonicalize_dtype(qkv.dtype)
batch, max_seqlen, nqkv, num_head, head_dim = qkv.shape
assert nqkv == 3
assert qkv.dtype == bias.dtype
output_shape = (batch, max_seqlen, num_head, head_dim)
output_dtype = qkv_dtype
softmax_aux_shape = (batch, num_head, max_seqlen, max_seqlen)
softmax_dtype = qkv_dtype
return (
ShapedArray(output_shape, output_dtype, named_shape=qkv.named_shape), # output
ShapedArray(softmax_aux_shape, softmax_dtype,
named_shape=qkv.named_shape), # softmax_aux
)
@staticmethod
def lowering(ctx, qkv, bias, cu_seqlen, rng_state, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
"""
Self fused attention max seqlen 512 fwd lowering rules
"""
qkv_aval, _, _, _ = ctx.avals_in
ir_qkv_type = ir.RankedTensorType(qkv.type)
ir_qkv_shape = ir_qkv_type.shape
ir_bias_type = ir.RankedTensorType(bias.type)
ir_bias_shape = ir_bias_type.shape
ir_cu_seqlen_type = ir.RankedTensorType(cu_seqlen.type)
ir_cu_seqlen_shape = ir_cu_seqlen_type.shape
ir_rng_state_type = ir.RankedTensorType(rng_state.type)
ir_rng_state_shape = ir_rng_state_type.shape
batch, max_seqlen, nqkv, num_head, head_dim = ir_qkv_shape
assert nqkv == 3
output_shape = (batch, max_seqlen, num_head, head_dim)
softmax_aux_shape = (batch, num_head, max_seqlen, max_seqlen)
out_types = [
ir.RankedTensorType.get(output_shape, ir_qkv_type.element_type),
ir.RankedTensorType.get(softmax_aux_shape, ir_qkv_type.element_type)
]
operands = [qkv, bias, cu_seqlen, rng_state]
operand_shapes = [ir_qkv_shape, ir_bias_shape, ir_cu_seqlen_shape, ir_rng_state_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnMax512FwdPrimitive.name,
args,
opaque,
has_side_effect=False)
return out
_self_fused_attn_max_512_fwd_p = register_primitive(SelfFusedAttnMax512FwdPrimitive)
def self_fused_attn_max_512_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, cu_seqlen: jnp.ndarray,
rng_state: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE self fused attention max seqlen 512 fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
# Jax can't bind None, create a dummy tensor for None
if rng_state is None:
rng_state = jnp.zeros(2, dtype=jnp.int32)
if bias is None:
assert attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
bias = jnp.zeros(0, dtype=qkv.dtype)
return _self_fused_attn_max_512_fwd_p.bind(qkv,
bias,
cu_seqlen,
rng_state,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
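A hypothetical call site for the wrapper above, with shapes following its abstract rule (qkv is [batch, max_seqlen, 3, num_head, head_dim] and cu_seqlen holds batch + 1 cumulative lengths). This assumes the Python-side NVTE_Bias_Type/NVTE_Mask_Type enums expose the same member names as the C enums; it is illustrative only and not taken from the repository's tests:

```python
import jax
import jax.numpy as jnp

batch, max_seqlen, num_head, head_dim = 4, 512, 16, 64
qkv = jax.random.normal(jax.random.PRNGKey(0),
                        (batch, max_seqlen, 3, num_head, head_dim), jnp.bfloat16)
# All four sequences fully populated -> cumulative lengths [0, 512, 1024, 1536, 2048].
cu_seqlen = jnp.arange(0, (batch + 1) * max_seqlen, max_seqlen, dtype=jnp.int32)

output, softmax_aux = self_fused_attn_max_512_fwd(
    qkv,
    bias=None,                                    # NVTE_NO_BIAS, so no bias tensor needed
    cu_seqlen=cu_seqlen,
    rng_state=None,                               # dropout disabled below
    attn_bias_type=NVTE_Bias_Type.NVTE_NO_BIAS,
    attn_mask_type=NVTE_Mask_Type.NVTE_PADDING_MASK,
    scaling_factor=1.0 / (head_dim ** 0.5),
    dropout_probability=0.0,
    is_training=True)
```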
class SelfFusedAttnMax512BwdPrimitive(BasePrimitive):
"""
Self Fused Attention Max Seqlen 512 Backward Primitive
"""
name = "te_self_fused_attn_max_512_backward"
multiple_results = True
@staticmethod
def abstract(
qkv,
softmax_aux,
doutput,
cu_seqlen, # pylint: disable=unused-argument
*,
attn_bias_type, # pylint: disable=unused-argument
attn_mask_type, # pylint: disable=unused-argument
scaling_factor, # pylint: disable=unused-argument
dropout_probability, # pylint: disable=unused-argument
is_training # pylint: disable=unused-argument
):
"""
Self fused attention bwd abstract
"""
qkv_dtype = dtypes.canonicalize_dtype(qkv.dtype)
assert qkv.dtype == softmax_aux.dtype == doutput.dtype
_, seqlen, _, num_head, _ = qkv.shape
bias_shape = (1, num_head, seqlen, seqlen)
bias_dtype = qkv_dtype
return (
ShapedArray(qkv.shape, qkv_dtype, named_shape=qkv.named_shape), # dqkv
ShapedArray(bias_shape, bias_dtype, named_shape=qkv.named_shape))
@staticmethod
def lowering(ctx, qkv, softmax_aux, doutput, cu_seqlen, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
"""
Self fused attention max seqlen 512 bwd lowering rules
"""
qkv_aval, _, _, _ = ctx.avals_in
ir_qkv_type = ir.RankedTensorType(qkv.type)
ir_qkv_shape = ir_qkv_type.shape
ir_softmax_aux_type = ir.RankedTensorType(softmax_aux.type)
ir_softmax_aux_shape = ir_softmax_aux_type.shape
ir_doutput_type = ir.RankedTensorType(doutput.type)
ir_doutput_shape = ir_doutput_type.shape
ir_cu_seqlen_type = ir.RankedTensorType(cu_seqlen.type)
ir_cu_seqlen_shape = ir_cu_seqlen_type.shape
batch, max_seqlen, num_head, head_dim = ir_doutput_shape
dbias_shape = (1, num_head, max_seqlen, max_seqlen)
dbias_dtype = ir_qkv_type.element_type
out_types = [
ir.RankedTensorType.get(ir_qkv_shape, ir_qkv_type.element_type),
ir.RankedTensorType.get(dbias_shape, dbias_dtype)
]
operands = [qkv, softmax_aux, doutput, cu_seqlen]
operand_shapes = [ir_qkv_shape, ir_softmax_aux_shape, ir_doutput_shape, ir_cu_seqlen_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
out = custom_caller(SelfFusedAttnMax512BwdPrimitive.name,
args,
opaque,
has_side_effect=False)
return out
_self_fused_attn_max_512_bwd_p = register_primitive(SelfFusedAttnMax512BwdPrimitive)
def self_fused_attn_max_512_bwd(qkv: jnp.ndarray, softmax_aux: jnp.ndarray, doutput: jnp.ndarray,
cu_seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE self fused attention max seqlen 512 bwd
Return the gradients of self fused attention with packed qkv input
"""
return _self_fused_attn_max_512_bwd_p.bind(qkv,
softmax_aux,
doutput,
cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
"""
Cross Fused Attention Forward Max Seqlen 512 Primitive
"""
name = "te_cross_fused_attn_max_512_forward"
multiple_results = True
@staticmethod
def abstract(
q,
kv,
q_cu_seqlen,
kv_cu_seqlen,
rng_state, # pylint: disable=unused-argument
*,
attn_bias_type, # pylint: disable=unused-argument
attn_mask_type, # pylint: disable=unused-argument
scaling_factor, # pylint: disable=unused-argument
dropout_probability, # pylint: disable=unused-argument
is_training # pylint: disable=unused-argument
):
"""
Cross fused attention max seqlen 512 fwd abstract
"""
q_dtype = dtypes.canonicalize_dtype(q.dtype)
batch_q, q_max_seqlen, num_head_q, head_dim_q = q.shape
kv_dtype = dtypes.canonicalize_dtype(kv.dtype)
batch_kv, kv_max_seqlen, nkv, num_head_kv, head_dim_kv = kv.shape
assert q_dtype == kv_dtype
assert batch_q == batch_kv
assert num_head_q == num_head_kv
assert head_dim_q == head_dim_kv
assert nkv == 2
assert q_cu_seqlen.dtype == kv_cu_seqlen.dtype
output_shape = q.shape
output_dtype = q_dtype
softmax_aux_shape = (batch_q, num_head_q, q_max_seqlen, kv_max_seqlen)
softmax_aux_dtype = q_dtype
return (
ShapedArray(output_shape, output_dtype, named_shape=q.named_shape), # output
ShapedArray(softmax_aux_shape, softmax_aux_dtype,
named_shape=q.named_shape), # softmax_aux
)
@staticmethod
def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Cross fused attention max seqlen 512 fwd lowering rules
"""
q_aval, kv_aval, _, _, _ = ctx.avals_in
assert q_aval.dtype == kv_aval.dtype
ir_q_type = ir.RankedTensorType(q.type)
ir_q_shape = ir_q_type.shape
ir_kv_type = ir.RankedTensorType(kv.type)
ir_kv_shape = ir_kv_type.shape
ir_q_cu_seqlen_shape = ir.RankedTensorType(q_cu_seqlen.type).shape
ir_kv_cu_seqlen_shape = ir.RankedTensorType(kv_cu_seqlen.type).shape
ir_rng_state_type = ir.RankedTensorType(rng_state.type)
ir_rng_state_shape = ir_rng_state_type.shape
batch, q_max_seqlen, num_head, head_dim = ir_q_shape
kv_max_seqlen = ir_kv_shape[1]
output_shape = (batch, q_max_seqlen, num_head, head_dim)
softmax_aux_shape = (batch, num_head, q_max_seqlen, kv_max_seqlen)
out_types = [
ir.RankedTensorType.get(output_shape, ir_q_type.element_type),
ir.RankedTensorType.get(softmax_aux_shape, ir_q_type.element_type)
]
operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state]
operand_shapes = [
ir_q_shape, ir_kv_shape, ir_q_cu_seqlen_shape, ir_kv_cu_seqlen_shape, ir_rng_state_shape
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
out = custom_caller(CrossFusedAttnMax512FwdPrimitive.name,
args,
opaque,
has_side_effect=False)
return out
_cross_fused_attn_max_512_fwd_p = register_primitive(CrossFusedAttnMax512FwdPrimitive)
def cross_fused_attn_max_512_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: jnp.ndarray,
kv_cu_seqlen: jnp.ndarray, rng_state: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float,
is_training: bool):
"""
Wrapper for TE cross fused attention max seqlen 512 fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
# Jax can't bind None, create a dummy tensor for None
if rng_state is None:
rng_state = jnp.zeros(2, dtype=jnp.int32)
return _cross_fused_attn_max_512_fwd_p.bind(q,
kv,
q_cu_seqlen,
kv_cu_seqlen,
rng_state,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
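# --- Illustrative sketch (not part of this commit) ---------------------------------------
# A rough unfused JAX reference for the pipeline named in the docstring above
# (BMM1 -> ScaleMaskSoftmax -> BMM2). Bias, padding mask and dropout are omitted, and the
# helper name `unfused_reference` is an assumption made for illustration only.
import jax
import jax.numpy as jnp


def unfused_reference(q, kv, scaling_factor):
    # q:  [batch, q_seqlen, num_head, head_dim]
    # kv: [batch, kv_seqlen, 2, num_head, head_dim] (packed key/value)
    k, v = kv[:, :, 0], kv[:, :, 1]
    logits = jnp.einsum('bqhd,bkhd->bhqk', q, k) * scaling_factor    # BMM1 + scale
    probs = jax.nn.softmax(logits, axis=-1)                          # softmax (mask omitted)
    return jnp.einsum('bhqk,bkhd->bqhd', probs, v)                   # BMM2
# ------------------------------------------------------------------------------------------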
class CrossFusedAttnMax512BwdPrimitive(BasePrimitive):
"""
Cross Fused Attention Max Seqlen 512 Backward Primitive
"""
name = "te_cross_fused_attn_max_512_backward"
multiple_results = True
@staticmethod
def abstract(
q,
kv,
softmax_aux,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
*,
attn_bias_type, # pylint: disable=unused-argument
attn_mask_type, # pylint: disable=unused-argument
scaling_factor, # pylint: disable=unused-argument
dropout_probability, # pylint: disable=unused-argument
is_training # pylint: disable=unused-argument
):
"""
Cross fused attention max seqlen 512 bwd abstract
"""
q_dtype = dtypes.canonicalize_dtype(q.dtype)
kv_dtype = dtypes.canonicalize_dtype(kv.dtype)
softmax_aux_dtype = dtypes.canonicalize_dtype(softmax_aux.dtype)
doutput_dtype = dtypes.canonicalize_dtype(doutput.dtype)
assert q_dtype == kv_dtype == softmax_aux_dtype == doutput_dtype
assert q_cu_seqlen.dtype == kv_cu_seqlen.dtype
return (
ShapedArray(q.shape, q_dtype, named_shape=q.named_shape), # dq
ShapedArray(kv.shape, kv_dtype, named_shape=kv.named_shape), # dkv
)
@staticmethod
def lowering(ctx, q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
"""
Cross fused attention max seqlen 512 bwd lowering rules
"""
q_aval, _, _, _, _, _ = ctx.avals_in
ir_q_type = ir.RankedTensorType(q.type)
ir_q_shape = ir_q_type.shape
ir_kv_type = ir.RankedTensorType(kv.type)
ir_kv_shape = ir_kv_type.shape
ir_softmax_aux_type = ir.RankedTensorType(softmax_aux.type)
ir_softmax_aux_shape = ir_softmax_aux_type.shape
ir_doutput_shape = ir.RankedTensorType(doutput.type).shape
ir_q_cu_seqlen_shape = ir.RankedTensorType(q_cu_seqlen.type).shape
ir_kv_cu_seqlen_shape = ir.RankedTensorType(kv_cu_seqlen.type).shape
batch, q_max_seqlen, num_head, head_dim = ir_doutput_shape
kv_max_seqlen = ir_kv_shape[1]
out_types = [
ir.RankedTensorType.get(ir_q_shape, ir_q_type.element_type),
ir.RankedTensorType.get(ir_kv_shape, ir_kv_type.element_type),
]
operands = [q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen]
operand_shapes = [
ir_q_shape, ir_kv_shape, ir_softmax_aux_shape, ir_doutput_shape, ir_q_cu_seqlen_shape,
ir_kv_cu_seqlen_shape
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
jax_dtype_to_te_dtype(q_aval.dtype), is_training)
out = custom_caller(CrossFusedAttnMax512BwdPrimitive.name,
args,
opaque,
has_side_effect=False)
return out
_cross_fused_attn_max_512_bwd_p = register_primitive(CrossFusedAttnMax512BwdPrimitive)
def cross_fused_attn_max_512_bwd(q: jnp.ndarray, kv: jnp.ndarray, softmax_aux: jnp.ndarray,
doutput: jnp.ndarray, q_cu_seqlen: jnp.ndarray,
kv_cu_seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE cross fused attention max seqlen 512 bwd
Return the gradients of cross fused attention with packed kv input
"""
return _cross_fused_attn_max_512_bwd_p.bind(q,
kv,
softmax_aux,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
...@@ -6,6 +6,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "common/include/transformer_engine/fused_attn.h"
#include "common/include/transformer_engine/transformer_engine.h" #include "common/include/transformer_engine/transformer_engine.h"
#include "jax/csrc/modules.h" #include "jax/csrc/modules.h"
#include "jax/csrc/utils.h" #include "jax/csrc/utils.h"
...@@ -43,6 +44,11 @@ pybind11::dict Registrations() { ...@@ -43,6 +44,11 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward); EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
dict["te_scaled_upper_triang_masked_softmax_backward"] = dict["te_scaled_upper_triang_masked_softmax_backward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
dict["te_self_fused_attn_max_512_forward"] = EncapsulateFunction(SelfFusedAttnMax512Forward);
dict["te_self_fused_attn_max_512_backward"] = EncapsulateFunction(SelfFusedAttnMax512Backward);
dict["te_cross_fused_attn_max_512_forward"] = EncapsulateFunction(CrossFusedAttnMax512Forward);
dict["te_cross_fused_attn_max_512_backward"] =
EncapsulateFunction(CrossFusedAttnMax512Backward);
return dict;
}
...@@ -52,15 +58,28 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
m.def("is_fused_attn_kernel_available", &IsFusedAttnKernelAvailable);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
.value("kInt32", DType::kInt32)
.value("KInt64", DType::kInt64)
.value("kFloat32", DType::kFloat32) .value("kFloat32", DType::kFloat32)
.value("kFloat16", DType::kFloat16) .value("kFloat16", DType::kFloat16)
.value("kBFloat16", DType::kBFloat16) .value("kBFloat16", DType::kBFloat16)
.value("kFloat8E4M3", DType::kFloat8E4M3) .value("kFloat8E4M3", DType::kFloat8E4M3)
.value("kFloat8E5M2", DType::kFloat8E5M2); .value("kFloat8E5M2", DType::kFloat8E5M2);
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local())
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
.value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", pybind11::module_local())
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
}
} // namespace jax
...
...@@ -9,6 +9,7 @@
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <functional>
#include <numeric>
...@@ -19,6 +20,7 @@
#include "common/common.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/gemm.h" #include "transformer_engine/gemm.h"
#include "transformer_engine/layer_norm.h" #include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h" #include "transformer_engine/rmsnorm.h"
...@@ -78,6 +80,25 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, ...@@ -78,6 +80,25 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch,
SoftmaxDescriptor{batch, pad_batch, heads, q_seqlen, k_seqlen, dtype, scale_factor}); SoftmaxDescriptor{batch, pad_batch, heads, q_seqlen, k_seqlen, dtype, scale_factor});
} }
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{batch, num_head, q_max_seqlen, kv_max_seqlen,
head_dim, scaling_factor, dropout_probability,
bias_type, mask_type, dtype, is_training});
}
bool IsFusedAttnKernelAvailable() {
#if (CUDNN_VERSION >= 8901)
auto major = cudaDevicePropertiesManager::Instance().GetMajor();
// Fused attention requires at least Ampere
return major >= 8;
#else
return false;
#endif
}
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
void *output) {
auto input_shape = std::vector<size_t>{rows, cols};
...@@ -718,5 +739,333 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
grad_output_tensor.data(), softmax_output_tensor.data(), dgrad_tensor.data(),
desc.scale_factor, stream);
}
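// Note: XLA passes `buffers` to each custom call below with the operands first and the
// results last, in the order declared by the matching JAX lowering rule; the pointers are
// therefore unpacked positionally.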
void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input
void *qkv = buffers[0];
void *bias = buffers[1];
void *cu_seqlens = buffers[2];
void *rng_state = buffers[3];
// output
void *output = buffers[4];
void *softmax_aux = buffers[5];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor =
TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{1}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
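// The first call is a workspace query: with an empty workspace tensor the kernel only
// fills query_workspace_tensor with the required workspace shape/dtype and does no work.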
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
descriptor.mask_type, query_workspace_tensor.data(), stream);
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
size_t workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
descriptor.mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input
void *qkv = buffers[0];
void *softmax_aux = buffers[1];
void *doutput = buffers[2];
void *cu_seqlens = buffers[3];
// output
void *dqkv = buffers[4];
void *dp = softmax_aux;
void *dbias = buffers[5];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
"q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
auto dtype = descriptor.dtype;
auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// It's a little trick that the flash attn needs fwd output
// But when seqlen <= 512, it is not needed
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
// Currently, no rng_state required for bwd
auto rng_state = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt64);
// TODO: need to think about how to pass aux_output_tensors
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
aux_output_tensors.size = 1;
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.shape = std::vector<size_t>{batch, num_head, q_max_seqlen, kv_max_seqlen};
output_s->data.dptr = softmax_aux;
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
descriptor.mask_type, query_workspace_tensor.data(), stream);
size_t workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
descriptor.mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input
void *q = buffers[0];
void *kv = buffers[1];
void *q_cu_seqlens = buffers[2];
void *kv_cu_seqlens = buffers[3];
void *rng_state = buffers[4];
// output
void *output = buffers[5];
void *softmax_aux = buffers[6];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
// TODO(rewang): add bias for cross attn?
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor =
TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{1}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
query_workspace_tensor.data(), stream);
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
size_t workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
void CrossFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
// input
void *q = buffers[0];
void *kv = buffers[1];
void *softmax_aux = buffers[2];
void *doutput = buffers[3];
void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5];
// output
void *dq = buffers[6];
void *dkv = buffers[7];
void *dp = softmax_aux;
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim};
auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// It's a little trick that the flash attn needs fwd output
// But when seqlen <= 512, it is not needed
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
// TODO(rewang): generalize cross attn
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
// Currently, no rng_state required for bwd
auto rng_state = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt64);
// TODO(rewang): need to think about how to pass aux_output_tensors
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
aux_output_tensors.size = 1;
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.shape = std::vector<size_t>{batch * num_head, q_max_seqlen, kv_max_seqlen};
output_s->data.dptr = softmax_aux;
TensorWrapper query_workspace_tensor;
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
query_workspace_tensor.data(), stream);
size_t workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for FP16/BF16
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
} // namespace jax
} // namespace transformer_engine
...@@ -17,6 +17,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/logging.h" #include "transformer_engine/logging.h"
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
...@@ -94,6 +95,27 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, ...@@ -94,6 +95,27 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch,
size_t q_seqlen, size_t k_seqlen, DType dtype, size_t q_seqlen, size_t k_seqlen, DType dtype,
float scale_factor); float scale_factor);
struct CustomCallFusedAttnDescriptor {
size_t batch;
size_t num_head;
size_t q_max_seqlen;
size_t kv_max_seqlen;
size_t head_dim;
float scaling_factor;
float dropout_probability;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
DType dtype;
bool is_training;
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, DType dtype, bool is_training);
bool IsFusedAttnKernelAvailable();
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...@@ -144,6 +166,18 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void CrossFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
} // namespace jax
} // namespace transformer_engine
...
...@@ -75,6 +75,16 @@ class cudaDevicePropertiesManager {
return prop_.multiProcessorCount;
}
int GetMajor() {
if (!prop_queried_) {
int device_id;
NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
cudaGetDeviceProperties(&prop_, device_id);
prop_queried_ = true;
}
return prop_.major;
}
private:
bool prop_queried_ = false;
cudaDeviceProp prop_;
...
...@@ -6,18 +6,24 @@ Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
from math import sqrt
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import warnings
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import dtypes
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..fused_attn import AttnBiasType, AttnMaskType
from ..fused_attn import is_fused_attn_kernel_available
from ..fused_attn import self_fused_attn, cross_fused_attn
from ..softmax import SoftmaxType
from ..sharding import infer_major_sharding_type, infer_sharding_type
from ..sharding import global_shard_resource, ShardingType
...@@ -129,6 +135,7 @@ def combine_biases(*masks: Optional[Array]):
def core_attention(query: Array,
key: Array,
value: Array,
scale_factor: float,
transpose_batch_sequence: bool,
softmax_type: SoftmaxType = SoftmaxType.SCALED,
softmax_sharding_type: ShardingType = ShardingType.SINGLE,
...@@ -159,6 +166,7 @@ def core_attention(query: Array,
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
attn_weights = Softmax(softmax_type=softmax_type,
scale_factor=scale_factor,
sharding_type=softmax_sharding_type)(attn_weights, mask, bias)
if not deterministic and dropout_rate > 0.:
...@@ -181,8 +189,8 @@ dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, N
class AttentionType(Enum):
"""TransformerLayerType."""
PADDING = "padding_attention"
CAUSAL = "causal_attention"
PADDING = AttnMaskType.PADDING_MASK
CAUSAL = AttnMaskType.CAUSAL_MASK
class MultiHeadAttention(nn.Module):
...@@ -312,9 +320,8 @@ class MultiHeadAttention(nn.Module):
Output tensors.
"""
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
def query_init(*args):
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)
def qkv_init(key, shape, dtype):
...@@ -349,6 +356,43 @@ class MultiHeadAttention(nn.Module):
first_sharding_type, second_sharding_type = infer_sharding_type()
canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
fused_attn_supported_seqlen = [128, 256, 384, 512]
use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
self.dropout_rate == 0 and canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \
and is_fused_attn_kernel_available()
if not use_fused_attn:
reason = ""
if decode:
reason += f"decode=False is required but got {decode}, "
if self.transpose_batch_sequence:
reason += f"transpose_batch_sequence=False is required " \
f"but got {self.transpose_batch_sequence}, "
if not self.fuse_qkv:
reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
if self.dropout_rate != 0:
# TODO(rewang): add dropout support
reason += f"no dropout is required but got dropout_rate={self.dropout_rate}, "
if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]:
reason += f"dtype in [BF16, FP16] is required " \
f"but got dtype={canonicalize_dtype}, "
if q_seqlen not in fused_attn_supported_seqlen:
reason += f"q_seqlen in {fused_attn_supported_seqlen} is required " \
f"but got {q_seqlen=}, "
if kv_seqlen not in fused_attn_supported_seqlen:
reason += f"kv_seqlen in {fused_attn_supported_seqlen} is required " \
f"but got {kv_seqlen=}, "
if not is_fused_attn_kernel_available():
reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, "
warnings.warn(
f"Fused attention is not enabled, " \
f"{reason}fall back to unfused attention")
residual = inputs_q
if self.fuse_qkv:
if inputs_q is inputs_kv:
...@@ -369,12 +413,8 @@ class MultiHeadAttention(nn.Module):
bias_init=self.bias_init,
name='qkv',
dtype=self.dtype)(inputs_q)
query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
query = jnp.reshape(query, (*query.shape[:-2], -1))
key = jnp.reshape(key, (*key.shape[:-2], -1))
value = jnp.reshape(value, (*value.shape[:-2], -1))
if self.scale_attn_logits:
query = query / depth_scaling
if not use_fused_attn:
query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
else:
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
...@@ -386,7 +426,6 @@ class MultiHeadAttention(nn.Module):
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
depth_scaling=depth_scaling if self.scale_attn_logits else None,
scale_axes=('embed',),
kernel_axes=('embed', 'joined_kv'),
use_bias=self.use_bias,
...@@ -404,11 +443,8 @@
bias_init=self.bias_init,
name='kv',
dtype=self.dtype)(inputs_kv)
key, value = jnp.split(kv_proj, [
1,
], axis=-2)
key = jnp.reshape(key, (*key.shape[:-2], -1))
value = jnp.reshape(value, (*value.shape[:-2], -1))
if not use_fused_attn:
key, value = jnp.split(kv_proj, [1], axis=-2)
else:
kv_projection = functools.partial(
DenseGeneral,
...@@ -430,7 +466,6 @@ class MultiHeadAttention(nn.Module):
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True,
depth_scaling=depth_scaling if self.scale_attn_logits else None,
scale_axes=('embed',),
kernel_axes=('embed', 'joined_kv'),
use_bias=self.use_bias,
...@@ -446,21 +481,21 @@
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
if self.apply_residual_connection_post_layernorm:
assert ln_out is not None
residual = ln_out
qkv_sharding_constraint = \
('length', 'batch', 'heads','kv') \
if self.transpose_batch_sequence \
else ('batch', 'length', 'heads', 'kv')
query = nn_partitioning.with_sharding_constraint(query, qkv_sharding_constraint)
key = nn_partitioning.with_sharding_constraint(key, qkv_sharding_constraint)
value = nn_partitioning.with_sharding_constraint(value, qkv_sharding_constraint)
if not use_fused_attn:
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
qkv_sharding_constraint = \
('length', 'batch', 'heads','kv') \
if self.transpose_batch_sequence \
else ('batch', 'length', 'heads', 'kv')
query = nn_partitioning.with_sharding_constraint(query, qkv_sharding_constraint)
key = nn_partitioning.with_sharding_constraint(key, qkv_sharding_constraint)
value = nn_partitioning.with_sharding_constraint(value, qkv_sharding_constraint)
if decode:
is_initialized = self.has_variable('cache', 'cached_key')
...@@ -502,30 +537,74 @@ class MultiHeadAttention(nn.Module):
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
jnp.reshape(cur_index, (-1)), 1, -2)
scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
dropout_rng = None
if not deterministic and self.dropout_rate > 0.:
dropout_rng = self.make_rng(self.dropout_rng_name)
softmax_type = SoftmaxType.SCALED
if self.attn_type is AttentionType.PADDING:
if mask is not None:
softmax_type = SoftmaxType.SCALED_MASKED
else:
softmax_type = SoftmaxType.SCALED_UPPER_TRIANG_MASKED
x = core_attention(query,
key,
value,
transpose_batch_sequence=self.transpose_batch_sequence,
softmax_type=softmax_type,
softmax_sharding_type=first_sharding_type,
mask=mask,
bias=bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
deterministic=deterministic,
dtype=self.dtype,
float32_logits=self.float32_logits)
if use_fused_attn:
assert mask is not None and mask.ndim == 4    # (b, 1, s_q, s_kv)
assert not self.transpose_batch_sequence
# TODO(rewang): make it configurable for pre_scale_bias
attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
if inputs_q is inputs_kv:
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = ('batch', 'length', 'qkv_dim', 'heads', 'kv')
qkv_proj = nn_partitioning.with_sharding_constraint(qkv_proj,
qkv_sharding_constraint)
x = self_fused_attn(qkv_proj,
bias,
mask,
dropout_rng,
attn_bias_type=attn_bias_type,
attn_mask_type=self.attn_type.value,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic,
sharding_type=first_sharding_type)
else:
assert bias is None
query = query.reshape((*query.shape[:-1], self.num_heads, self.head_dim))
kv_proj = kv_proj.reshape((*kv_proj.shape[:-1], self.num_heads, self.head_dim))
q_sharding_constraint = ('batch', 'length', 'heads', 'kv')
kv_sharding_constraint = ('batch', 'length', 'kv_dim', 'heads', 'kv')
query = nn_partitioning.with_sharding_constraint(query, q_sharding_constraint)
kv_proj = nn_partitioning.with_sharding_constraint(kv_proj, kv_sharding_constraint)
x = cross_fused_attn(query,
kv_proj,
mask,
dropout_rng,
attn_bias_type=attn_bias_type,
attn_mask_type=self.attn_type.value,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic,
sharding_type=first_sharding_type)
else:
softmax_type = SoftmaxType.SCALED
if self.attn_type is AttentionType.PADDING:
if mask is not None:
softmax_type = SoftmaxType.SCALED_MASKED
else:
softmax_type = SoftmaxType.SCALED_UPPER_TRIANG_MASKED
x = core_attention(query,
key,
value,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
softmax_type=softmax_type,
softmax_sharding_type=first_sharding_type,
mask=mask,
bias=bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
deterministic=deterministic,
dtype=self.dtype,
float32_logits=self.float32_logits)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
...
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""
from enum import Enum
from functools import partial
import jax
import jax.numpy as jnp
import transformer_engine_jax
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from .cpp_extensions import cross_fused_attn_max_512_fwd, cross_fused_attn_max_512_bwd
from .cpp_extensions import self_fused_attn_max_512_fwd, self_fused_attn_max_512_bwd
from .sharding import get_fused_attn_sharding_meta
from .sharding import ShardingType
from .sharding import xmap_runner
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
def is_fused_attn_kernel_available():
"""
To check whether the fused attention kernel is available
"""
return transformer_engine_jax.is_fused_attn_kernel_available()
class AttnBiasType(Enum):
"""Attention Bias Type."""
NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS
PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS
POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS
class AttnMaskType(Enum):
"""Attention Mask Type."""
NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK
PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
def self_fused_attn(qkv: jnp.ndarray,
bias: jnp.ndarray,
mask: jnp.ndarray,
rng_state: jnp.ndarray,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
sharding_type: ShardingType = ShardingType.SINGLE):
"""
Self fused attention wrapper
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"Fused_attn_max_512 does not support row-split tensor parallelism currently."
if sharding_type is ShardingType.SINGLE:
output = _self_fused_attn_max_512(qkv,
bias,
mask,
rng_state,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
inputs = [qkv, bias, mask, rng_state]
batch, seqlen, _, num_head, head_dim = qkv.shape
output_shape = [batch, seqlen, num_head, head_dim]
sharding_meta = get_fused_attn_sharding_meta(
sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
dp_dims=([0, None, 0, None], [0]),
tp_dims=([3, 1, None, None], [2]),
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
inputs_ = tuple(
jnp.reshape(x, new_shape) if x is not None else None
for x, new_shape in zip(inputs, sharding_meta.input_shapes))
partial_self_fused_attn_max_512 = partial(_self_fused_attn_max_512,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output_ = xmap_runner(partial_self_fused_attn_max_512, sharding_meta.in_axes,
sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes[0])
return output
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn_max_512(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
rng_state: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
output, _ = _self_fused_attn_max_512_fwd(qkv,
bias,
mask,
rng_state,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output
def _self_fused_attn_max_512_fwd(qkv, bias, mask, rng_state, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
cu_seqlen = jnp.cumsum(seqlen)
cu_seqlen = jnp.hstack((0, cu_seqlen))
output, softmax_aux = self_fused_attn_max_512_fwd(qkv,
bias,
cu_seqlen,
rng_state,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, (qkv, softmax_aux, cu_seqlen)
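# Worked example (illustrative, not part of the commit): the conversion above treats 0 in the
# mask as a valid position, so for two sequences of lengths 3 and 5 the kernels receive the
# cumulative-offset format
#   seqlen    = [3, 5]
#   cu_seqlen = [0, 3, 8]
# which can be reproduced with
#   jnp.hstack((0, jnp.cumsum(jnp.asarray([3, 5], dtype=jnp.int32))))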
def _self_fused_attn_max_512_bwd(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, ctx, grad):
qkv, softmax_aux, cu_seqlen = ctx
doutput = grad
grad_qkv, grad_bias = self_fused_attn_max_512_bwd(qkv,
softmax_aux,
doutput,
cu_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return grad_qkv, grad_bias, None, None
_self_fused_attn_max_512.defvjp(_self_fused_attn_max_512_fwd, _self_fused_attn_max_512_bwd)
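# Illustrative usage sketch (not part of the commit). Shapes and dtypes are assumptions, and
# passing None for bias/dropout_rng relies on the None guards added for the fwd primitives:
#   b, s, h, d = 2, 128, 16, 64
#   qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
#   mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)      # 0 marks valid positions
#   out = self_fused_attn(qkv, None, mask, None,
#                         attn_bias_type=AttnBiasType.NO_BIAS,
#                         attn_mask_type=AttnMaskType.PADDING_MASK,
#                         scaling_factor=1.0 / d ** 0.5,
#                         dropout_probability=0.0,
#                         is_training=True)               # out: [b, s, h, d]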
def cross_fused_attn(q: jnp.ndarray,
kv: jnp.ndarray,
mask: jnp.ndarray,
rng_state: jnp.ndarray,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
sharding_type: ShardingType = ShardingType.SINGLE):
"""
Cross multi-head attention wrapper
"""
assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
"Fused_attn_max_512 does not support row-split tensor parallelism currently."
if sharding_type is ShardingType.SINGLE:
output = _cross_fused_attn_max_512(q,
kv,
mask,
rng_state,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
inputs = [q, kv, mask, rng_state]
output_shape = q.shape
sharding_meta = get_fused_attn_sharding_meta(
sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
dp_dims=([0, 0, 0, None], [0]),
tp_dims=([2, 3, None, None], [2]),
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
inputs_ = tuple(
jnp.reshape(x, new_shape) if x is not None else None
for x, new_shape in zip(inputs, sharding_meta.input_shapes))
partial_cross_fused_attn_max_512 = partial(_cross_fused_attn_max_512,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output_ = xmap_runner(partial_cross_fused_attn_max_512, sharding_meta.in_axes,
sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_)
output = jnp.reshape(output_, sharding_meta.output_shapes[0])
return output
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray,
rng_state: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
output, _ = _cross_fused_attn_max_512_fwd(q,
kv,
mask,
rng_state,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output
def _cross_fused_attn_max_512_fwd(q, kv, mask, rng_state, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
q_cu_seqlen = jnp.cumsum(q_seqlen)
q_cu_seqlen = jnp.hstack((0, q_cu_seqlen))
kv_seqlen = jnp.sum(mask[:, :, 0, :] == 0, axis=(-1, -2), dtype=jnp.int32)
kv_cu_seqlen = jnp.cumsum(kv_seqlen)
kv_cu_seqlen = jnp.hstack((0, kv_cu_seqlen))
output, softmax_aux = cross_fused_attn_max_512_fwd(q,
kv,
q_cu_seqlen,
kv_cu_seqlen,
rng_state,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, (softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen)
def _cross_fused_attn_max_512_bwd(attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training, ctx, grad):
softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen = ctx
doutput = grad
grad_q, grad_kv = cross_fused_attn_max_512_bwd(q,
kv,
softmax_aux,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return grad_q, grad_kv, None, None
_cross_fused_attn_max_512.defvjp(_cross_fused_attn_max_512_fwd, _cross_fused_attn_max_512_bwd)
...@@ -8,6 +8,7 @@ Sharding Meta for xmap with CustomCall
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from itertools import repeat
from typing import Union, Tuple, Dict, Callable, Sequence
from jax.interpreters import pxla
import jax
...@@ -315,6 +316,121 @@ class FP8MetaShardingMetaGenerator(ShardingMetaGenerator):
axis_resource, (), ())
class FusedAttnShardingMetaGenerator(ShardingMetaGenerator):
"""
FusedAttnShardingMetaGenerator
"""
def get_dp_sharding_meta(
self,
input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]],
tp_dims: Tuple[Tuple[int, ...]], # pylint: disable=unused-argument
dp_axis_name: str = 'data',
tp_axis_name: str = 'model' # pylint: disable=unused-argument
) -> ShardingMeta:
"""get_dp_sharding_meta"""
dummy_tp_dims = [repeat(None), repeat(None)]
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(input_shapes, output_shapes,
dp_dims, dummy_tp_dims,
dp_axis_name, None)
def get_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_col_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_tp_sharding_meta(*argv, **kwargs)
def get_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_tp_row_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_tp_sharding_meta(*argv, **kwargs)
def get_dp_tp_col_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_col_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(*argv, **kwargs)
def get_dp_tp_row_sharding_meta(self, *argv, **kwargs) -> ShardingMeta:
"""get_dp_tp_row_sharding_meta"""
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(*argv, **kwargs)
@staticmethod
def _get_tp_sharding_meta(
input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]], # pylint: disable=unused-argument
tp_dims: Tuple[Tuple[int, ...]],
dp_axis_name: str = 'data', # pylint: disable=unused-argument
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_tp_sharding_meta"""
dummy_dp_dims = [repeat(None), repeat(None)]
return FusedAttnShardingMetaGenerator._get_dptp_sharding_meta(input_shapes, output_shapes,
dummy_dp_dims, tp_dims, None,
tp_axis_name)
@staticmethod
def _get_dptp_sharding_meta(input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]],
tp_dims: Tuple[Tuple[int, ...]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""get_dp_tp_sharding_meta"""
dp_size, dp_mesh_axis = _get_mesh_info(global_shard_resource().dp_resource)
tp_size, tp_mesh_axis = _get_mesh_info(global_shard_resource().tp_resource)
input_dp_dims, output_dp_dims = dp_dims
input_tp_dims, output_tp_dims = tp_dims
input_new_shapes = []
in_axes = []
for input_shape, dp_dim, tp_dim in zip(input_shapes, input_dp_dims, input_tp_dims):
in_axis = {}
if dp_dim is not None:
in_axis[dp_dim] = dp_axis_name
assert input_shape[dp_dim] % dp_size == 0, \
f"The dimension of batch in input_shape should be a multiple of " \
f"data parallelism size, but got {input_shape[dp_dim]=} and {dp_size=}."
input_shape = (*input_shape[:dp_dim], dp_size, input_shape[dp_dim] // dp_size,
*input_shape[dp_dim + 1:])
# the input shape has been expanded for dp_dim, tp_dim should +1 if tp_dim >= dp_dim
if tp_dim is not None and tp_dim >= dp_dim:
tp_dim = tp_dim + 1
if tp_dim is not None:
in_axis[tp_dim] = tp_axis_name
assert input_shape[tp_dim] % tp_size == 0, \
f"The dimension of tensor parallel in input_shape should be a multiple of " \
f"tensor parallelism size, but got {input_shape[tp_dim]=} and {tp_size=}."
input_shape = (*input_shape[:tp_dim], tp_size, input_shape[tp_dim] // tp_size,
*input_shape[tp_dim + 1:])
in_axes.append(in_axis)
input_new_shapes.append(input_shape)
output_new_shapes = output_shapes
out_axes = []
for dp_dim, tp_dim in zip(output_dp_dims, output_tp_dims):
out_axis = {}
if dp_dim is not None:
out_axis[dp_dim] = dp_axis_name
if tp_dim is not None and tp_dim >= dp_dim:
tp_dim = tp_dim + 1
if tp_dim is not None:
out_axis[tp_dim] = tp_axis_name
out_axes.append(out_axis)
axis_resources = {}
if dp_axis_name is not None:
axis_resources[dp_axis_name] = dp_mesh_axis
if tp_axis_name is not None:
axis_resources[tp_axis_name] = tp_mesh_axis
return ShardingMeta(tuple(in_axes), out_axes, axis_resources, input_new_shapes,
output_new_shapes)
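# Worked example (illustrative, not part of the commit): for the self-attention qkv input of
# shape [batch, seqlen, 3, num_head, head_dim] with dp_dim=0 and tp_dim=3 under dp_size=2 and
# tp_size=4, the loop above reshapes the input to
#   [2, batch // 2, seqlen, 3, 4, num_head // 4, head_dim]
# and records in_axis = {0: dp_axis_name, 4: tp_axis_name}; tp_dim is shifted by one because
# the dp dimension was expanded first.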
class DotShardingMetaGenerator(ShardingMetaGenerator):
"""
DotShardingMetaGenerator
...@@ -884,6 +1000,21 @@ def get_softmax_sharding_meta(stype: ShardingType,
dp_axis_name, tp_axis_name)
def get_fused_attn_sharding_meta(stype: ShardingType,
input_shapes: Tuple[Tuple[int, ...]],
output_shapes: Tuple[Tuple[int, ...]],
dp_dims: Tuple[Tuple[int, ...]],
tp_dims: Tuple[Tuple[int, ...]],
dp_axis_name: str = 'data',
tp_axis_name: str = 'model') -> ShardingMeta:
"""
get_fused_attn_sharding_meta
"""
return FusedAttnShardingMetaGenerator().get_sharding_meta(stype, input_shapes, output_shapes,
dp_dims, tp_dims, dp_axis_name,
tp_axis_name)
def xmap_runner(func: Callable, in_axes: Tuple[Dict, ...],
out_axes: Union[Dict, Tuple[str, ...], Tuple[Union[Dict, Tuple], ...]],
axis_resources: Dict, inputs: Tuple):
...