Unverified Commit 66ff2e36 authored by zlsh80826, committed by GitHub

[JAX] flash attention integration (#345)



* Fix flash attention dropout probability with inference
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add output as the fused attention ctx tensor
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add rng_state as a fused attention ctx tensor
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add flash attention supported lengths to the fused attention
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refactor attention primitive to reuse abstract shaped array
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Detect backend type to allocate appropriate ctx size
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Skip dropout correctness instead of returning success
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use cudaMemsetAsync and enhance the error handling
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add flash attention kernel elts_per_thread update
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove redundant max 512 suffix
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Keep only DType and remove NVTEDType from python
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix a float32_attention_logits bug
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Re-calculate workspace size for self attention
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance bias/dbias shape guard
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance the seed/rng_state checker
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use jax.core.ShapedArray as jax.abstract_arrays is deprecated (a short sketch follows this change list)
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance the unittest docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 403ade2f
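As noted in the change list, the tests and primitives now import `ShapedArray` from `jax.core` instead of the deprecated `jax.abstract_arrays` module. A minimal, stand-alone sketch of the replacement import; the example shape and dtype are illustrative only:

```python
import jax.numpy as jnp
# from jax.abstract_arrays import ShapedArray   # deprecated import path
from jax.core import ShapedArray                 # import path this PR switches to

# Abstract-eval rules for the custom-call primitives return descriptors like this.
out_aval = ShapedArray((4, 16, 64), jnp.float16)
print(out_aval.shape, out_aval.dtype)            # (4, 16, 64) float16
```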
@@ -4,7 +4,7 @@
 import pytest
 import jax.numpy as jnp
-from jax.abstract_arrays import ShapedArray
+from jax.core import ShapedArray
 from transformer_engine_jax import DType
 from transformer_engine.jax.cpp_extensions import te_dtype_to_jax_dtype
...
@@ -547,7 +547,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));

   if (!is_training) {
-    dropout_probability == 0.0f;
+    dropout_probability = 0.0f;
   }

   FADescriptor descriptor{b, h,
@@ -1144,7 +1144,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
   void *devPtrSoftmaxSum = static_cast<int8_t *>(workspace) + plan_workspace_size;
   void *devPtrdQAccumulator = static_cast<int8_t *>(devPtrSoftmaxSum)
                               + softmaxSum_workspace_size;
-  NVTE_CHECK_CUDA(cudaMemset(devPtrdQAccumulator, 0, dqAccum_workspace_size));
+  NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQAccumulator, 0, dqAccum_workspace_size, stream));

   std::set<std::pair<uint64_t, void *>> data_ptrs;
   // add all the data pointers to be used in the variant pack
@@ -1224,6 +1224,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     devPtrS = output_S->data.dptr;
     Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
     output_rng_state->data.dptr = rng_state->data.dptr;
+  } else {
+    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
   }

   void* devPtrDropoutSeed = rng_state->data.dptr;
@@ -1250,6 +1252,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     workspace->data.shape = {1};
     workspace->data.dtype = DType::kByte;
     return;
+  } else {
+    NVTE_ERROR("Unexpected workspace_size.");
   }
 }
@@ -1312,6 +1316,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
     workspace->data.shape = {1};
     workspace->data.dtype = DType::kByte;
     return;
+  } else {
+    NVTE_ERROR("Unexpected workspace_size.");
   }
 }
 }  // namespace transformer_engine
...
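The first hunk above fixes an accidental comparison (`==`) that should have been an assignment, so inference runs with dropout disabled. A hedged Python restatement of that intent; the helper name is illustrative, not part of the library:

```python
def effective_dropout(dropout_probability: float, is_training: bool) -> float:
    # Assignment, not comparison: evaluation/inference must not apply dropout.
    if not is_training:
        dropout_probability = 0.0
    return dropout_probability

assert effective_dropout(0.1, is_training=False) == 0.0
assert effective_dropout(0.1, is_training=True) == 0.1
```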
@@ -1275,6 +1275,8 @@ void fused_attn_max_512_fwd_qkvpacked(
   } else if (Aux_CTX_Tensors->size == 1) {
     Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
+  } else {
+    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
   }

   void *devPtrCuSeqlen = cu_seqlens->data.dptr;
@@ -1304,6 +1306,8 @@ void fused_attn_max_512_fwd_qkvpacked(
     workspace->data.shape = {1};
     workspace->data.dtype = DType::kByte;
     return;
+  } else {
+    NVTE_ERROR("Unexpected workspace_size.");
   }
 }
@@ -1351,6 +1355,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
   } else if (Aux_CTX_Tensors->size == 1) {
     Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
+  } else {
+    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
   }

   void *devQCuSeqlen = q_cu_seqlens->data.dptr;
@@ -1380,6 +1386,8 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
     workspace->data.shape = {1};
     workspace->data.dtype = DType::kByte;
     return;
+  } else {
+    NVTE_ERROR("Unexpected workspace_size.");
   }
 }
@@ -1440,6 +1448,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
     workspace->data.shape = {1};
     workspace->data.dtype = DType::kByte;
     return;
+  } else {
+    NVTE_ERROR("Unexpected workspace_size.");
   }
 }
@@ -1503,6 +1513,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
     workspace->data.shape = {1};
     workspace->data.dtype = DType::kByte;
     return;
+  } else {
+    NVTE_ERROR("Unexpected workspace_size.");
   }
 }
 }  // namespace transformer_engine
...
@@ -46,11 +46,10 @@ pybind11::dict Registrations() {
         EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
     dict["te_scaled_upper_triang_masked_softmax_backward"] =
         EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
-    dict["te_self_fused_attn_max_512_forward"] = EncapsulateFunction(SelfFusedAttnMax512Forward);
-    dict["te_self_fused_attn_max_512_backward"] = EncapsulateFunction(SelfFusedAttnMax512Backward);
-    dict["te_cross_fused_attn_max_512_forward"] = EncapsulateFunction(CrossFusedAttnMax512Forward);
-    dict["te_cross_fused_attn_max_512_backward"] =
-        EncapsulateFunction(CrossFusedAttnMax512Backward);
+    dict["te_self_fused_attn_forward"] = EncapsulateFunction(SelfFusedAttnForward);
+    dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward);
+    dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward);
+    dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward);
     return dict;
 }
@@ -65,6 +64,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
     m.def("get_device_compute_capability", &GetDeviceComputeCapability);
     m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
     m.def("is_fused_attn_kernel_available", &IsFusedAttnKernelAvailable);
+    m.def("get_fused_attn_backend", &GetFusedAttnBackend);

     pybind11::enum_<DType>(m, "DType", pybind11::module_local())
         .value("kByte", DType::kByte)
@@ -85,6 +85,17 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
         .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
         .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
         .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
+
+    pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
+        .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
+        .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
+        .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED);
+
+    pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
+        .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
+        .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
+        .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
+        .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8);
 }
 }  // namespace jax
...
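With the bindings above, Python code can ask which fused-attention backend would serve a given problem shape. A hedged sketch of such a query: the argument order follows the `GetFusedAttnBackend` signature shown elsewhere in this diff, and enum members not visible here (e.g. `DType.kBFloat16`) are assumptions:

```python
import transformer_engine_jax as tex

# q_dtype, kv_dtype, qkv_layout, bias_type, mask_type,
# dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim
backend = tex.get_fused_attn_backend(
    tex.DType.kBFloat16, tex.DType.kBFloat16,
    tex.NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED,
    tex.NVTE_Bias_Type.NVTE_NO_BIAS,
    tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK,
    0.1, 2048, 2048, 64)

if backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
    print("flash-attention (arbitrary seqlen) path")
elif backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
    print("max-512 path")
```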
@@ -740,8 +740,19 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
                                               desc.scale_factor, stream);
 }

-void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
-                                size_t opaque_len) {
+NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
+                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                            NVTE_Mask_Type mask_type, float dropout_probability,
+                                            size_t q_max_seqlen, size_t kv_max_seqlen,
+                                            size_t head_dim) {
+  auto backend = nvte_get_fused_attn_backend(
+      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
+      mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
+  return backend;
+}
+
+void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
+                          size_t opaque_len) {
   const CustomCallFusedAttnDescriptor &descriptor =
       *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
@@ -754,12 +765,17 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
   // output
   void *output = buffers[4];
   void *softmax_aux = buffers[5];
+  void *rng_state = buffers[6];

   auto batch = descriptor.batch;
   auto num_head = descriptor.num_head;
   auto q_max_seqlen = descriptor.q_max_seqlen;
   auto kv_max_seqlen = descriptor.kv_max_seqlen;
   auto head_dim = descriptor.head_dim;
+  auto dropout_probability = descriptor.dropout_probability;
+  auto bias_type = descriptor.bias_type;
+  auto mask_type = descriptor.mask_type;
+  constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;

   NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
              "q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
@@ -768,78 +784,81 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
   auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
   auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};

+  // input tensors
   auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
   auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
+  auto cu_seqlens_tensor =
+      TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);

-  // FP16/BF16 doesn't use this tensor
-  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
+  // output tensors
   auto o_tensor =
       TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
-  auto cu_seqlens_tensor =
-      TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
-  auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
+
+  // F16 doesn't use this tensor
+  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
+
+  // aux tensors
+  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
+
+  auto backend = nvte_get_fused_attn_backend(
+      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
+      mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
+  PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

   NVTETensorPack aux_output_tensors;
   nvte_tensor_pack_create(&aux_output_tensors);

   TensorWrapper query_workspace_tensor;
-  nvte_fused_attn_fwd_qkvpacked(
-      qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
-      &aux_output_tensors, cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
-      descriptor.is_training, descriptor.scaling_factor, descriptor.dropout_probability,
-      NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
-      query_workspace_tensor.data(), stream);
+  nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
+                                o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
+                                rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
+                                descriptor.scaling_factor, dropout_probability, qkv_layout,
+                                bias_type, mask_type, query_workspace_tensor.data(), stream);

   auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
   output_s->data.dptr = softmax_aux;

-  // fused attn workspace + workspace for rng_state
-  auto plan_workspace_size =
-      query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
-  auto rng_workspace_size = 2 * sizeof(int64_t);
-  auto total_workspace_size = plan_workspace_size + rng_workspace_size;
-  auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size);
+  auto workspace_size = query_workspace_tensor.shape().data[0];
+  auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);

   auto workspace_tensor =
       TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());

-  auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size;
-  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
-  PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, stream);
-
   nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
                                 o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
                                 rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
-                                descriptor.scaling_factor, descriptor.dropout_probability,
-                                NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
-                                descriptor.mask_type, workspace_tensor.data(), stream);
+                                descriptor.scaling_factor, dropout_probability, qkv_layout,
+                                bias_type, mask_type, workspace_tensor.data(), stream);

   nvte_tensor_pack_destroy(&aux_output_tensors);
 }
-void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
-                                 size_t opaque_len) {
+void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
+                           size_t opaque_len) {
   const CustomCallFusedAttnDescriptor &descriptor =
       *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

   // input
   void *qkv = buffers[0];
   void *softmax_aux = buffers[1];
-  void *doutput = buffers[2];
-  void *cu_seqlens = buffers[3];
+  void *rng_state = buffers[2];
+  void *output = buffers[3];
+  void *doutput = buffers[4];
+  void *cu_seqlens = buffers[5];

   // output
-  void *dqkv = buffers[4];
-  void *dp = softmax_aux;
-  void *dbias = buffers[5];
+  void *dqkv = buffers[6];
+  void *dbias = buffers[7];

   auto batch = descriptor.batch;
   auto num_head = descriptor.num_head;
   auto q_max_seqlen = descriptor.q_max_seqlen;
   auto kv_max_seqlen = descriptor.kv_max_seqlen;
   auto head_dim = descriptor.head_dim;
+  auto dropout_probability = descriptor.dropout_probability;
+  auto bias_type = descriptor.bias_type;
+  auto mask_type = descriptor.mask_type;
+  constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED;

   NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
              "q_max_seqlen should be equal to kv_max_seqlen in the self attention.");
@@ -850,11 +869,9 @@ void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char
   auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};

   auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
+  auto output_tensor = TensorWrapper(output, output_shape, dtype);
   auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
-  // It's a little trick that the flash attn needs fwd output
-  // But when seqlen <= 512, it is not needed
-  auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
-  // FP16/BF16 doesn't use this tensor
+  // F16 doesn't use this tensor
   auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
   auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
@@ -862,50 +879,47 @@ void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char
   auto cu_seqlens_tensor =
       TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);

-  // Currently, no rng_state required for bwd
-  auto rng_state = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt64);
-
   // TODO: needs to think about how to pass aux_output_tensors
   NVTETensorPack aux_output_tensors;
   nvte_tensor_pack_create(&aux_output_tensors);
-  aux_output_tensors.size = 1;
+  aux_output_tensors.size = 2;
   auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
+  output_s->data.shape = std::vector<size_t>{batch, num_head, q_max_seqlen, kv_max_seqlen};
   output_s->data.dptr = softmax_aux;
+  auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
+  rng_state_tensor->data.shape = std::vector<size_t>{2};
+  rng_state_tensor->data.dtype = DType::kInt64;
+  rng_state_tensor->data.dptr = rng_state;

   TensorWrapper query_workspace_tensor;
   nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
-                                s_tensor.data(),  // not used for FP16/BF16
-                                s_tensor.data(),  // not used for FP16/BF16
+                                s_tensor.data(),  // not used for F16
+                                s_tensor.data(),  // not used for F16
                                 &aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
                                 cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
-                                descriptor.dropout_probability,
-                                NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
-                                descriptor.mask_type, query_workspace_tensor.data(), stream);
+                                dropout_probability, qkv_layout, bias_type, mask_type,
+                                query_workspace_tensor.data(), stream);

-  size_t workspace_size =
-      query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
+  size_t workspace_size = query_workspace_tensor.shape().data[0];
   auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);

   auto workspace_tensor =
       TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());

   nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
-                                s_tensor.data(),  // not used for FP16/BF16
-                                s_tensor.data(),  // not used for FP16/BF16
+                                s_tensor.data(),  // not used for F16
+                                s_tensor.data(),  // not used for F16
                                 &aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
                                 cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
-                                descriptor.dropout_probability,
-                                NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
-                                descriptor.mask_type, workspace_tensor.data(), stream);
+                                dropout_probability, qkv_layout, bias_type, mask_type,
+                                workspace_tensor.data(), stream);

   nvte_tensor_pack_destroy(&aux_output_tensors);
 }

-void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
-                                 size_t opaque_len) {
+void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
+                           size_t opaque_len) {
   const CustomCallFusedAttnDescriptor &descriptor =
       *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
@@ -925,6 +939,10 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
   auto q_max_seqlen = descriptor.q_max_seqlen;
   auto kv_max_seqlen = descriptor.kv_max_seqlen;
   auto head_dim = descriptor.head_dim;
+  auto dropout_probability = descriptor.dropout_probability;
+  auto bias_type = descriptor.bias_type;
+  auto mask_type = descriptor.mask_type;
+  constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_KV_INTERLEAVED;
   auto dtype = descriptor.dtype;

   auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
@@ -958,8 +976,7 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
       q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
       &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
       dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
-      descriptor.scaling_factor, descriptor.dropout_probability,
-      NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
+      descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
       query_workspace_tensor.data(), stream);

   auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
@@ -976,21 +993,24 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
   auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size;
   auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
-  PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, stream);
+
+  auto backend = nvte_get_fused_attn_backend(
+      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
+      mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
+  PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

   nvte_fused_attn_fwd_kvpacked(
       q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
       &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
       rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
-      descriptor.scaling_factor, descriptor.dropout_probability,
-      NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
+      descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
       workspace_tensor.data(), stream);

   nvte_tensor_pack_destroy(&aux_output_tensors);
 }

-void CrossFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
-                                  size_t opaque_len) {
+void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
+                            size_t opaque_len) {
   const CustomCallFusedAttnDescriptor &descriptor =
       *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
...
@@ -116,6 +116,12 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
 bool IsFusedAttnKernelAvailable();

+NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
+                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                            NVTE_Mask_Type mask_type, float dropout_probability,
+                                            size_t q_max_seqlen, size_t kv_max_seqlen,
+                                            size_t head_dim);
+
 void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

 void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
@@ -166,17 +172,17 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
 void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                             std::size_t opaque_len);

-void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
-                                size_t opaque_len);
+void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
+                          size_t opaque_len);

-void SelfFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
-                                 size_t opaque_len);
+void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
+                           size_t opaque_len);

-void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char *opaque,
-                                 size_t opaque_len);
+void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
+                           size_t opaque_len);

-void CrossFusedAttnMax512Backward(cudaStream_t stream, void **buffers, const char *opaque,
-                                  size_t opaque_len);
+void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
+                            size_t opaque_len);

 }  // namespace jax
 }  // namespace transformer_engine
...
@@ -41,9 +41,15 @@ __global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t
 }

 void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
-                           size_t kv_max_seqlen, cudaStream_t stream) {
-  constexpr int threads_per_cta = 128;
-  const size_t increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
+                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
+                           cudaStream_t stream) {
+  size_t increment = 0;
+  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
+    increment = 16;
+  } else {
+    constexpr int threads_per_cta = 128;
+    increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
+  }
   auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
   populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
                                                  reinterpret_cast<const int64_t *>(seed), offset);
...
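The hunk above makes the Philox offset increment depend on the selected backend. A small Python mirror of that arithmetic, assuming the same constants as the C++ code (a fixed 16 for the arbitrary-seqlen/flash path, 128 threads per CTA otherwise); the function name is illustrative only:

```python
def rng_offset_increment(q_max_seqlen: int, kv_max_seqlen: int,
                         arbitrary_seqlen_backend: bool) -> int:
    if arbitrary_seqlen_backend:
        return 16
    threads_per_cta = 128
    # ceil-divide: one offset step per CTA-worth of attention scores
    return (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) // threads_per_cta

assert rng_offset_increment(512, 512, arbitrary_seqlen_backend=False) == 2048
assert rng_offset_increment(2048, 2048, arbitrary_seqlen_backend=True) == 16
```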
@@ -13,6 +13,7 @@
 #include <stdexcept>
 #include <string>
 #include <type_traits>
+#include "transformer_engine/fused_attn.h"
 #include "transformer_engine/logging.h"

 namespace transformer_engine {
@@ -22,7 +23,8 @@ int GetCudaRuntimeVersion();
 int GetDeviceComputeCapability(int gpu_id);

 void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
-                           size_t kv_max_seqlen, cudaStream_t stream);
+                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
+                           cudaStream_t stream);

 class cublasLtMetaManager {
  public:
...
@@ -178,7 +178,8 @@ def core_attention(query: Array,
         attn_weights = Softmax(softmax_type=softmax_type,
                                scale_factor=fused_scale_factor,
-                               sharding_type=softmax_sharding_type)(attn_weights, mask, bias)
+                               sharding_type=softmax_sharding_type)(attn_weights, mask,
+                                                                    bias).astype(dtype)

         if not deterministic and dropout_rate > 0.:
             keep_prob = 1.0 - dropout_rate
@@ -369,12 +370,20 @@ class MultiHeadAttention(nn.Module):
         canonicalize_dtype = dtypes.canonicalize_dtype(self.dtype)
         q_seqlen = inputs_q.shape[0] if self.transpose_batch_sequence else inputs_q.shape[1]
         kv_seqlen = inputs_kv.shape[0] if self.transpose_batch_sequence else inputs_kv.shape[1]
-        fused_attn_supported_seqlen = [128, 256, 384, 512]
         enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))

+        def _check_seqlen(seqlen):
+            return seqlen % 64 == 0
+
+        def _check_head_dim(head_dim):
+            return head_dim in [64, 128]
+
         use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
             canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
-            q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \
-            and is_fused_attn_kernel_available() and (self.head_dim == 64) and enable_fused_attn
+            _check_seqlen(q_seqlen) and _check_seqlen(kv_seqlen) and \
+            _check_head_dim(self.head_dim) and \
+            is_fused_attn_kernel_available() and \
+            enable_fused_attn

         if enable_fused_attn and not use_fused_attn:
             reason = ""
@@ -388,16 +397,16 @@ class MultiHeadAttention(nn.Module):
             if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]:
                 reason += f"dtype in [BF16, FP16] is required " \
                           f"but got dtype={canonicalize_dtype}, "
-            if q_seqlen not in fused_attn_supported_seqlen:
-                reason += f"q_seqlen in {fused_attn_supported_seqlen} is required " \
+            if not _check_seqlen(q_seqlen):
+                reason += f"q_seqlen % 64 == 0 is required " \
                           f"but got {q_seqlen=}, "
-            if kv_seqlen not in fused_attn_supported_seqlen:
-                reason += f"kv_seqlen in {fused_attn_supported_seqlen} is required " \
+            if not _check_seqlen(kv_seqlen):
+                reason += f"kv_seqlen % 64 == 0 is required " \
                           f"but got {kv_seqlen=}, "
+            if not _check_head_dim(self.head_dim):
+                reason += f"head_dim should be 64 or 128 but got {self.head_dim}, "
             if not is_fused_attn_kernel_available():
                 reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, "
-            if self.head_dim != 64:
-                reason += f"head_dim should be 64 but got {self.head_dim}, "
             warnings.warn(
                 f"Fused attention is not enabled, " \
...
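The eligibility rules in the hunk above replace the fixed `[128, 256, 384, 512]` whitelist with divisibility and head-size checks. A self-contained restatement of those checks (the helper name is illustrative and not part of the module):

```python
def fused_attn_eligible(q_seqlen: int, kv_seqlen: int, head_dim: int) -> bool:
    def check_seqlen(seqlen: int) -> bool:
        return seqlen % 64 == 0
    return check_seqlen(q_seqlen) and check_seqlen(kv_seqlen) and head_dim in (64, 128)

print(fused_attn_eligible(640, 640, 64))    # True: 640 is a multiple of 64
print(fused_attn_eligible(512, 384, 128))   # True
print(fused_attn_eligible(500, 512, 64))    # False: 500 is not a multiple of 64
```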
@@ -12,8 +12,8 @@ import transformer_engine_jax
 from transformer_engine_jax import NVTE_Bias_Type
 from transformer_engine_jax import NVTE_Mask_Type

-from .cpp_extensions import cross_fused_attn_max_512_fwd, cross_fused_attn_max_512_bwd
-from .cpp_extensions import self_fused_attn_max_512_fwd, self_fused_attn_max_512_bwd
+from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd
+from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd
 from .sharding import get_fused_attn_sharding_meta
 from .sharding import ShardingType
 from .sharding import xmap_runner
@@ -57,18 +57,18 @@ def self_fused_attn(qkv: jnp.ndarray,
     Self fused attention wrapper
     """
     assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
-        "Fused_attn_max_512 does not support row-split tensor parallelism currently."
+        "self_fused_attn does not support row-split tensor parallelism currently."

     if sharding_type is ShardingType.SINGLE:
-        output = _self_fused_attn_max_512(qkv,
+        output = _self_fused_attn(qkv,
                                   bias,
                                   mask,
                                   seed,
                                   attn_bias_type=attn_bias_type,
                                   attn_mask_type=attn_mask_type,
                                   scaling_factor=scaling_factor,
                                   dropout_probability=dropout_probability,
                                   is_training=is_training)
     else:
         dp_axis_name = "batch"
         tp_axis_name = "model"
@@ -87,14 +87,14 @@ def self_fused_attn(qkv: jnp.ndarray,
             jnp.reshape(x, new_shape) if x is not None else None
             for x, new_shape in zip(inputs, sharding_meta.input_shapes))

-        partial_self_fused_attn_max_512 = partial(_self_fused_attn_max_512,
+        partial_self_fused_attn = partial(_self_fused_attn,
                                           attn_bias_type=attn_bias_type,
                                           attn_mask_type=attn_mask_type,
                                           scaling_factor=scaling_factor,
                                           dropout_probability=dropout_probability,
                                           is_training=is_training)

-        output_ = xmap_runner(partial_self_fused_attn_max_512, sharding_meta.in_axes,
+        output_ = xmap_runner(partial_self_fused_attn, sharding_meta.in_axes,
                               sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_)

         output = jnp.reshape(output_, sharding_meta.output_shapes[0])
@@ -103,61 +103,65 @@
 @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
-def _self_fused_attn_max_512(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
-                             seed: jnp.ndarray, attn_bias_type: AttnBiasType,
-                             attn_mask_type: AttnMaskType, scaling_factor: float,
-                             dropout_probability: float, is_training: bool):
-    output, _ = _self_fused_attn_max_512_fwd(qkv,
+def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
+                     attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
+                     scaling_factor: float, dropout_probability: float, is_training: bool):
+    output, _ = _self_fused_attn_fwd(qkv,
                                      bias,
                                      mask,
                                      seed,
                                      attn_bias_type=attn_bias_type,
                                      attn_mask_type=attn_mask_type,
                                      scaling_factor=scaling_factor,
                                      dropout_probability=dropout_probability,
                                      is_training=is_training)
     return output


-def _self_fused_attn_max_512_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
-                                 scaling_factor, dropout_probability, is_training):
+def _self_fused_attn_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
+                         dropout_probability, is_training):
     seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
     cu_seqlen = jnp.cumsum(seqlen)
     cu_seqlen = jnp.hstack((0, cu_seqlen))

-    output, softmax_aux = self_fused_attn_max_512_fwd(qkv,
+    output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
                                                          bias,
                                                          cu_seqlen,
                                                          seed,
                                                          attn_bias_type=attn_bias_type.value,
                                                          attn_mask_type=attn_mask_type.value,
                                                          scaling_factor=scaling_factor,
                                                          dropout_probability=dropout_probability,
                                                          is_training=is_training)

-    return output, (qkv, softmax_aux, cu_seqlen)
+    return output, (qkv, softmax_aux, rng_state, output, cu_seqlen)


-def _self_fused_attn_max_512_bwd(attn_bias_type, attn_mask_type, scaling_factor,
-                                 dropout_probability, is_training, ctx, grad):
-    qkv, softmax_aux, cu_seqlen = ctx
+def _self_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
+                         is_training, ctx, grad):
+    qkv, softmax_aux, rng_state, output, cu_seqlen = ctx

     doutput = grad

-    grad_qkv, grad_bias = self_fused_attn_max_512_bwd(qkv,
-                                                      softmax_aux,
-                                                      doutput,
-                                                      cu_seqlen,
-                                                      attn_bias_type=attn_bias_type.value,
-                                                      attn_mask_type=attn_mask_type.value,
-                                                      scaling_factor=scaling_factor,
-                                                      dropout_probability=dropout_probability,
-                                                      is_training=is_training)
+    grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
+                                              softmax_aux,
+                                              rng_state,
+                                              output,
+                                              doutput,
+                                              cu_seqlen,
+                                              attn_bias_type=attn_bias_type.value,
+                                              attn_mask_type=attn_mask_type.value,
+                                              scaling_factor=scaling_factor,
+                                              dropout_probability=dropout_probability,
+                                              is_training=is_training)
+
+    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
+        grad_bias = None

     return grad_qkv, grad_bias, None, None


-_self_fused_attn_max_512.defvjp(_self_fused_attn_max_512_fwd, _self_fused_attn_max_512_bwd)
+_self_fused_attn.defvjp(_self_fused_attn_fwd, _self_fused_attn_bwd)


 def cross_fused_attn(q: jnp.ndarray,
@@ -174,18 +178,18 @@ def cross_fused_attn(q: jnp.ndarray,
     Cross multi-head attention wrapper
     """
     assert sharding_type not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW), \
-        "Fused_attn_max_512 does not support row-split tensor parallelism currently."
+        "cross_fused_attn does not support row-split tensor parallelism currently."

     if sharding_type is ShardingType.SINGLE:
-        output = _cross_fused_attn_max_512(q,
+        output = _cross_fused_attn(q,
                                    kv,
                                    mask,
                                    seed,
                                    attn_bias_type=attn_bias_type,
                                    attn_mask_type=attn_mask_type,
                                    scaling_factor=scaling_factor,
                                    dropout_probability=dropout_probability,
                                    is_training=is_training)
     else:
         dp_axis_name = "batch"
         tp_axis_name = "model"
@@ -203,14 +207,14 @@ def cross_fused_attn(q: jnp.ndarray,
             jnp.reshape(x, new_shape) if x is not None else None
             for x, new_shape in zip(inputs, sharding_meta.input_shapes))

-        partial_cross_fused_attn_max_512 = partial(_cross_fused_attn_max_512,
+        partial_cross_fused_attn = partial(_cross_fused_attn,
                                            attn_bias_type=attn_bias_type,
                                            attn_mask_type=attn_mask_type,
                                            scaling_factor=scaling_factor,
                                            dropout_probability=dropout_probability,
                                            is_training=is_training)

-        output_ = xmap_runner(partial_cross_fused_attn_max_512, sharding_meta.in_axes,
+        output_ = xmap_runner(partial_cross_fused_attn, sharding_meta.in_axes,
                               sharding_meta.out_axes[0], sharding_meta.axis_resources, inputs_)

         output = jnp.reshape(output_, sharding_meta.output_shapes[0])
@@ -219,24 +223,24 @@ def cross_fused_attn(q: jnp.ndarray,
 @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
-def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
-                              attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
-                              scaling_factor: float, dropout_probability: float, is_training: bool):
-    output, _ = _cross_fused_attn_max_512_fwd(q,
+def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
+                      attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
+                      scaling_factor: float, dropout_probability: float, is_training: bool):
+    output, _ = _cross_fused_attn_fwd(q,
                                       kv,
                                       mask,
                                       seed,
                                       attn_bias_type=attn_bias_type,
                                       attn_mask_type=attn_mask_type,
                                       scaling_factor=scaling_factor,
                                       dropout_probability=dropout_probability,
                                       is_training=is_training)
     return output


-def _cross_fused_attn_max_512_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
-                                  dropout_probability, is_training):
+def _cross_fused_attn_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
+                          dropout_probability, is_training):
     q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
     q_cu_seqlen = jnp.cumsum(q_seqlen)
@@ -246,38 +250,38 @@ def _cross_fused_attn_max_512_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_t
     kv_cu_seqlen = jnp.cumsum(kv_seqlen)
     kv_cu_seqlen = jnp.hstack((0, kv_cu_seqlen))

-    output, softmax_aux = cross_fused_attn_max_512_fwd(q,
+    output, softmax_aux = cross_fused_attn_fwd(q,
                                                kv,
                                                q_cu_seqlen,
                                                kv_cu_seqlen,
                                                seed,
                                                attn_bias_type=attn_bias_type.value,
                                                attn_mask_type=attn_mask_type.value,
                                                scaling_factor=scaling_factor,
                                                dropout_probability=dropout_probability,
                                                is_training=is_training)

     return output, (softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen)


-def _cross_fused_attn_max_512_bwd(attn_bias_type, attn_mask_type, scaling_factor,
-                                  dropout_probability, is_training, ctx, grad):
+def _cross_fused_attn_bwd(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
+                          is_training, ctx, grad):
     softmax_aux, q, kv, q_cu_seqlen, kv_cu_seqlen = ctx

     doutput = grad

-    grad_q, grad_kv = cross_fused_attn_max_512_bwd(q,
+    grad_q, grad_kv = cross_fused_attn_bwd(q,
                                            kv,
                                            softmax_aux,
                                            doutput,
                                            q_cu_seqlen,
                                            kv_cu_seqlen,
                                            attn_bias_type=attn_bias_type.value,
                                            attn_mask_type=attn_mask_type.value,
                                            scaling_factor=scaling_factor,
                                            dropout_probability=dropout_probability,
                                            is_training=is_training)

     return grad_q, grad_kv, None, None


-_cross_fused_attn_max_512.defvjp(_cross_fused_attn_max_512_fwd, _cross_fused_attn_max_512_bwd)
+_cross_fused_attn.defvjp(_cross_fused_attn_fwd, _cross_fused_attn_bwd)
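The VJP wiring above now carries the RNG state and the forward output through the residuals so the backward kernel can reuse them. A toy, self-contained sketch of that pattern under stated assumptions; `toy_attn` and its residuals are stand-ins, not the Transformer Engine kernels:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def toy_attn(x, scale):
    return jnp.tanh(x * scale)

def toy_attn_fwd(x, scale):
    out = jnp.tanh(x * scale)
    rng_state = jnp.zeros((2,), dtype=jnp.uint32)   # placeholder for the real rng_state
    return out, (x, rng_state, out)                 # stash extra residuals, as the fwd rule above does

def toy_attn_bwd(scale, ctx, grad):
    x, rng_state, out = ctx
    del x, rng_state                                # a real backward kernel would consume these
    return (grad * scale * (1.0 - out * out),)      # d/dx tanh(scale * x)

toy_attn.defvjp(toy_attn_fwd, toy_attn_bwd)
print(jax.grad(lambda x: toy_attn(x, 2.0).sum())(jnp.ones(4)))
```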