Unverified Commit 687697a7 authored by Reese Wang, committed by GitHub

[JAX] Add experimental internal-use THD (packed) fused attn API (#964)



* Integrate experimental ragged offset
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use per-sequence based offsets
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Format
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove v/o_seq_offsets
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add FP16 sanity tests and remove forward tests from the automatically run tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance input checks
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Separate fused attn into 2 different APIs and add the docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add experimental to the docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix lint
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add runtime segments check
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove finished TODO
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
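The unified `fused_attn` entry point now takes the query/key/value tensors as a single tuple whose arity follows `qkv_layout`, and the experimental THD (packed) path additionally threads per-sequence offsets through the same machinery. A hedged usage sketch, mirroring the call sites changed in this diff; the shapes, mask convention, and hyper-parameters below are illustrative assumptions:

```python
import jax.numpy as jnp
from transformer_engine.jax.attention import (
    fused_attn,
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
)

# Illustrative shapes for a qkv-packed (BS3HD) batch.
b, s, h, d = 2, 128, 16, 64
qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)  # mask convention assumed to match the tests

out = fused_attn(
    (qkv,),            # 1-tuple for BS3HD; (q, kv) and (q, k, v) for the other layouts
    None,              # bias
    mask,
    None,              # no dropout below, so no rng seed needed (as in the distributed tests)
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.PADDING_CAUSAL_MASK,
    qkv_layout=QKVLayout.BS3HD,
    scaling_factor=1.0 / d**0.5,
    dropout_probability=0.0,
    is_training=True,
)
```

The kv-packed and separate-tensor branches in `flax/transformer.py` below follow the same pattern with `(q, kv)` and `(q, k, v)` tuples.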
parent 7669bf3d
@@ -15,8 +15,7 @@ from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
fused_attn_qkvpacked,
fused_attn_kvpacked,
fused_attn,
AttnBiasType,
AttnMaskType,
QKVLayout,
@@ -120,11 +119,15 @@ class TestDistributedSelfAttn:
def target_func(qkv, bias, mask):
return jnp.mean(
fused_attn_qkvpacked(
qkv,
fused_attn(
(qkv,),
bias,
mask,
None,
None,
None,
None,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
@@ -252,12 +255,15 @@ class TestDistributedCrossAttn:
def target_func(q, kv, mask):
return jnp.mean(
fused_attn_kvpacked(
q,
kv,
fused_attn(
(q, kv),
None,
mask,
None,
None,
None,
None,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
@@ -137,6 +137,7 @@ struct CustomCallFusedAttnDescriptor {
size_t num_gqa_groups;
size_t bias_heads;
size_t head_dim;
size_t max_segments_per_seq;
size_t wkspace_size;
float scaling_factor;
float dropout_probability;
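`max_segments_per_seq` gives the backend a static upper bound on how many packed sequences (segments) can share one row of a THD batch, so workspace sizes stay shape-independent while the actual segment count is checked at runtime. A hedged sketch of the relationship, assuming a hypothetical per-segment length array (not the exact buffer layout the descriptor consumes):

```python
import numpy as np

# Hypothetical per-segment lengths for a packed (THD) batch with 2 rows;
# zeros mark unused segment slots. Layout is illustrative only.
seqlens = np.array([[128, 64, 32, 0],
                    [256,  0,  0, 0]], dtype=np.int32)

max_segments_per_seq = seqlens.shape[1]        # static bound carried in the descriptor
segments_per_row = (seqlens > 0).sum(axis=1)   # [3, 1]

assert segments_per_row.max() <= max_segments_per_seq
```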
@@ -151,9 +152,9 @@ struct CustomCallFusedAttnDescriptor {
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training);
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training);
// Transpose
@@ -249,7 +250,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
@@ -257,7 +259,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
@@ -66,13 +66,13 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training) {
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout,
dtype, wkspace_dtype, is_training});
head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
}
} // namespace jax
@@ -67,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
@@ -92,7 +93,15 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD);
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
.value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
.value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);
pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local())
.value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD)
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD)
.value("NVTE_THD", NVTE_QKV_Format::NVTE_THD);
pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
.value("GELU", NVTE_Activation_Type::GELU)
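Exposing `nvte_get_qkv_format` alongside the new `NVTE_QKV_Format` enum lets the Python side branch on the derived format (BSHD vs. THD) instead of enumerating individual layouts. A minimal sketch, assuming the pybind module is imported as `transformer_engine_jax`:

```python
import transformer_engine_jax as tex

layout = tex.NVTE_QKV_Layout.NVTE_THD_THD_THD
fmt = tex.nvte_get_qkv_format(layout)

if fmt == tex.NVTE_QKV_Format.NVTE_THD:
    # packed/ragged path: per-sequence offsets are required
    pass
elif fmt == tex.NVTE_QKV_Format.NVTE_BSHD:
    # dense path: no ragged offsets needed
    pass
```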
@@ -45,5 +45,29 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
NVTE_CHECK_CUDA(cudaGetLastError());
}
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) return;
if (cu_seqlen[tid] > 0) {
// atomicAdd only supports 32-bit integer types, hence uint32_t
atomicAdd(out, 1);
}
}
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
// the workspace must hold a single uint32_t (4 bytes)
uint32_t *dout = static_cast<uint32_t *>(workspace);
uint32_t hout{};
cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
constexpr int threads = 128;
const int blocks = (len - 1) / threads + 1;
get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
len, dout);
cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
return hout;
}
} // namespace jax
} // namespace transformer_engine
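For reference, the counting logic of `GetRuntimeNumSegments` expressed on the host; a hedged NumPy sketch of the same check (count of positive `cu_seqlen` entries), not part of the change itself:

```python
import numpy as np

def runtime_num_segments(cu_seqlen: np.ndarray) -> int:
    # Mirrors get_runtime_num_segments_kernel: count entries greater than zero.
    return int((cu_seqlen.astype(np.int32) > 0).sum())

assert runtime_num_segments(np.array([128, 192, 0, 0], dtype=np.int32)) == 2
```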
@@ -28,6 +28,8 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream);
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
class cudaDevicePropertiesManager {
public:
static cudaDevicePropertiesManager &Instance() {
@@ -26,7 +26,7 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..attention import AttnBiasType, AttnMaskType, QKVLayout
from ..attention import is_fused_attn_kernel_available, canonicalize_attn_mask_type
from ..attention import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
from ..attention import fused_attn
from ..softmax import SoftmaxType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
@@ -268,6 +268,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
scale_factor = self.scale_factor
del self.scale_factor
# TODO(rewang): integrate THD format
if self.qkv_layout == QKVLayout.BS3HD:
"""qkvpacked format, treat
query: qkvpacked tensor, shape = [..., 3, h, d]
@@ -277,13 +278,14 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
qkv_packed = query
if self.transpose_batch_sequence:
qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
x = fused_attn_qkvpacked(
qkv_packed,
x = fused_attn(
(qkv_packed,),
bias,
mask,
seed,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
qkv_layout=self.qkv_layout,
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic,
@@ -298,14 +300,14 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
if self.transpose_batch_sequence:
query = query.transpose([1, 0, 2, 3])
kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
x = fused_attn_kvpacked(
query,
kv_packed,
x = fused_attn(
(query, kv_packed),
bias,
mask,
seed,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
qkv_layout=self.qkv_layout,
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic,
@@ -316,14 +318,13 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
key = key.transpose([1, 0, 2, 3])
value = value.transpose([1, 0, 2, 3])
x = fused_attn(
query,
key,
value,
(query, key, value),
bias,
mask,
seed,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
qkv_layout=self.qkv_layout,
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic,
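All three branches of `_FusedDotProductAttention` now funnel into the single `fused_attn` call, with the tuple arity tracking `qkv_layout`. A hedged summary of that dispatch (the helper function and the `BSHD_BS2HD` member name for the kv-packed branch are assumptions mirroring the module code):

```python
def pack_for_fused_attn(qkv_layout, query, key=None, value=None):
    # qkv-packed: query already holds the [..., 3, h, d] tensor
    if qkv_layout == QKVLayout.BS3HD:
        return (query,)
    # kv-packed: key holds the [..., 2, h, d] tensor
    if qkv_layout == QKVLayout.BSHD_BS2HD:
        return (query, key)
    # separate q, k, v tensors
    return (query, key, value)
```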