Unverified commit 687697a7, authored by Reese Wang, committed by GitHub

[JAX] Add experimental, internally used THD (packed) fused attn API (#964)



* Integrate experimental ragged offset
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use per-sequence-based offsets
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Format
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove v/o_seq_offsets
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add FP16 sanity tests and remove forward tests from the automatically run tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance input checks
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Separate fused attn into two different APIs and add the docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add experimental to the docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix lint
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add runtime segments check
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove finished TODO
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 7669bf3d
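
The changes below fold the previous fused_attn_qkvpacked and fused_attn_kvpacked entry points into a single fused_attn call that takes a tuple of QKV tensors plus an explicit qkv_layout. A minimal sketch of the new call pattern, modeled on the BS3HD branch in the Flax changes further down; the shapes, the mask construction, and the hyperparameters here are illustrative assumptions, not part of this commit:

import jax
import jax.numpy as jnp
from transformer_engine.jax.attention import (
    fused_attn,
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
)

# Illustrative sizes: batch, sequence length, heads, head dim.
b, s, h, d = 2, 128, 16, 64
qkv = jax.random.normal(jax.random.PRNGKey(0), (b, s, 3, h, d), dtype=jnp.bfloat16)
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)  # assumed convention: 1 marks masked-out positions

out = fused_attn(
    (qkv,),          # qkvpacked layout takes a 1-tuple; kvpacked uses (q, kv); separate uses (q, k, v)
    None,            # bias
    mask,
    None,            # dropout RNG seed
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.CAUSAL_MASK,
    qkv_layout=QKVLayout.BS3HD,
    scaling_factor=1.0 / (d ** 0.5),
    dropout_probability=0.0,
    is_training=True,
)
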
@@ -15,8 +15,7 @@ from utils import make_causal_mask, make_self_mask
 from transformer_engine.jax import fp8_autocast
 from transformer_engine.jax.attention import (
     is_fused_attn_kernel_available,
-    fused_attn_qkvpacked,
-    fused_attn_kvpacked,
+    fused_attn,
     AttnBiasType,
     AttnMaskType,
     QKVLayout,
@@ -120,11 +119,15 @@ class TestDistributedSelfAttn:
         def target_func(qkv, bias, mask):
             return jnp.mean(
-                fused_attn_qkvpacked(
-                    qkv,
+                fused_attn(
+                    (qkv,),
                     bias,
                     mask,
                     None,
+                    None,
+                    None,
+                    None,
+                    None,
                     attn_bias_type=attn_bias_type,
                     attn_mask_type=attn_mask_type,
                     scaling_factor=scaling_factor,
@@ -252,12 +255,15 @@ class TestDistributedCrossAttn:
         def target_func(q, kv, mask):
             return jnp.mean(
-                fused_attn_kvpacked(
-                    q,
-                    kv,
+                fused_attn(
+                    (q, kv),
                     None,
                     mask,
                     None,
+                    None,
+                    None,
+                    None,
+                    None,
                     attn_bias_type=attn_bias_type,
                     attn_mask_type=attn_mask_type,
                     scaling_factor=scaling_factor,
...
[Two collapsed file diffs are not shown.]
@@ -137,6 +137,7 @@ struct CustomCallFusedAttnDescriptor {
   size_t num_gqa_groups;
   size_t bias_heads;
   size_t head_dim;
+  size_t max_segments_per_seq;
   size_t wkspace_size;
   float scaling_factor;
   float dropout_probability;
@@ -151,9 +152,9 @@
 pybind11::bytes PackCustomCallFusedAttnDescriptor(
     size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
     size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
-    bool is_training);
+    size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
+    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training);

 // Transpose
@@ -249,7 +250,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
     size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
     float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
+    size_t max_segments_per_seq);

 void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
@@ -257,7 +259,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
     size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
     float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
+    size_t max_segments_per_seq);

 void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...
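
Both workspace-size queries are exposed directly through the pybind11 module further down, so the Python side now passes max_segments_per_seq as the final positional argument. A rough sketch of such a call against the transformer_engine_jax extension; all numeric values and the specific enum members are illustrative assumptions:

import transformer_engine_jax as tex

wkspace_shapes_and_dtypes = tex.get_fused_attn_fwd_workspace_sizes(
    4,      # input_batch
    4,      # bias_batch
    512,    # q_max_seqlen
    512,    # kv_max_seqlen
    16,     # attn_heads
    16,     # num_gqa_groups
    1,      # bias_heads
    64,     # head_dim
    0.125,  # scaling_factor
    0.0,    # dropout_probability
    tex.NVTE_Bias_Type.NVTE_NO_BIAS,
    tex.NVTE_Mask_Type.NVTE_PADDING_MASK,
    tex.NVTE_QKV_Layout.NVTE_THD_THD_THD,
    tex.DType.kBFloat16,
    True,   # is_training
    1,      # max_segments_per_seq (new trailing argument)
)
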
@@ -66,13 +66,13 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin
 pybind11::bytes PackCustomCallFusedAttnDescriptor(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
     size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
-    bool is_training) {
+    size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
+    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training) {
   return PackOpaque(CustomCallFusedAttnDescriptor{
       input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
-      head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout,
-      dtype, wkspace_dtype, is_training});
+      head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
+      mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
 }

 }  // namespace jax
...
@@ -67,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
   m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
   m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
   m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
+  m.def("nvte_get_qkv_format", &nvte_get_qkv_format);

   pybind11::enum_<DType>(m, "DType", pybind11::module_local())
       .value("kByte", DType::kByte)
@@ -92,7 +93,15 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
   pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
       .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
       .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
-      .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD);
+      .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
+      .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
+      .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
+      .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);
+
+  pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local())
+      .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD)
+      .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD)
+      .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD);

   pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
       .value("GELU", NVTE_Activation_Type::GELU)
...
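
The new THD layout enumerators and the NVTE_QKV_Format enum let Python code ask the extension which memory format a given layout belongs to through the freshly bound nvte_get_qkv_format. A small usage sketch; the exact import path of the compiled extension is an assumption:

import transformer_engine_jax as tex

layout = tex.NVTE_QKV_Layout.NVTE_THD_THD_THD
fmt = tex.nvte_get_qkv_format(layout)

# THD layouts map to the ragged/packed THD format; BS*/SB* layouts map to BSHD/SBHD.
assert fmt == tex.NVTE_QKV_Format.NVTE_THD
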
@@ -45,5 +45,29 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
   NVTE_CHECK_CUDA(cudaGetLastError());
 }

+__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
+  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  if (tid >= len) return;
+
+  if (cu_seqlen[tid] > 0) {
+    // atomicAdd only supports 32-bit dtypes
+    atomicAdd(out, 1);
+  }
+}
+
+uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
+  // the workspace must hold at least 4 bytes (one uint32_t)
+  uint32_t *dout = static_cast<uint32_t *>(workspace);
+  uint32_t hout{};
+  cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
+  constexpr int threads = 128;
+  const int blocks = (len - 1) / threads + 1;
+  get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
+                                                                  len, dout);
+  cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
+  cudaStreamSynchronize(stream);
+  return hout;
+}
+
 }  // namespace jax
 }  // namespace transformer_engine
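
GetRuntimeNumSegments counts how many entries of the device-side cu_seqlen buffer are strictly positive, which equals the number of real segments as long as unused slots are padded with zeros or -1. A host-side equivalent of that check; the padding convention shown is an assumption for illustration:

import jax.numpy as jnp

# Cumulative sequence lengths for three real segments, padded with -1
# up to a fixed buffer size.
cu_seqlen = jnp.array([0, 128, 192, 224, -1, -1], dtype=jnp.int32)
runtime_num_segments = int(jnp.sum(cu_seqlen > 0))  # counts 128, 192, 224 -> 3
assert runtime_num_segments == 3
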
@@ -28,6 +28,8 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
                            size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                            cudaStream_t stream);

+uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
+
 class cudaDevicePropertiesManager {
  public:
   static cudaDevicePropertiesManager &Instance() {
...
@@ -26,7 +26,7 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
 from .module import LayerNorm, Softmax
 from ..attention import AttnBiasType, AttnMaskType, QKVLayout
 from ..attention import is_fused_attn_kernel_available, canonicalize_attn_mask_type
-from ..attention import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
+from ..attention import fused_attn
 from ..softmax import SoftmaxType
 from ..sharding import num_of_devices
 from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
@@ -268,6 +268,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
         scale_factor = self.scale_factor
         del self.scale_factor

+        # TODO(rewang): integrate THD format
         if self.qkv_layout == QKVLayout.BS3HD:
             """qkvpacked format, treat
             query: qkvpacked tensor, shape = [..., 3, h, d]
@@ -277,13 +278,14 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
             qkv_packed = query
             if self.transpose_batch_sequence:
                 qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
-            x = fused_attn_qkvpacked(
-                qkv_packed,
+            x = fused_attn(
+                (qkv_packed,),
                 bias,
                 mask,
                 seed,
                 attn_mask_type=self.attn_mask_type,
                 attn_bias_type=self.attn_bias_type,
+                qkv_layout=self.qkv_layout,
                 scaling_factor=scale_factor,
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
@@ -298,14 +300,14 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
             if self.transpose_batch_sequence:
                 query = query.transpose([1, 0, 2, 3])
                 kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
-            x = fused_attn_kvpacked(
-                query,
-                kv_packed,
+            x = fused_attn(
+                (query, kv_packed),
                 bias,
                 mask,
                 seed,
                 attn_mask_type=self.attn_mask_type,
                 attn_bias_type=self.attn_bias_type,
+                qkv_layout=self.qkv_layout,
                 scaling_factor=scale_factor,
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
@@ -316,14 +318,13 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                 key = key.transpose([1, 0, 2, 3])
                 value = value.transpose([1, 0, 2, 3])
             x = fused_attn(
-                query,
-                key,
-                value,
+                (query, key, value),
                 bias,
                 mask,
                 seed,
                 attn_mask_type=self.attn_mask_type,
                 attn_bias_type=self.attn_bias_type,
+                qkv_layout=self.qkv_layout,
                 scaling_factor=scale_factor,
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
...
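
For background on the THD (packed) format this PR targets, and which the TODO above still has to wire into the Flax module: several variable-length segments share one token dimension per batch row, and the kernels locate them through per-segment lengths and per-sequence start offsets padded out to max_segments_per_seq. A rough illustration of how such descriptors could be laid out; the array names and the -1 padding value are assumptions, not the API added here:

import jax.numpy as jnp

max_segments_per_seq = 5
segment_lengths = [128, 64, 32]  # three real segments packed into one row

# Per-segment lengths, padded with -1 up to max_segments_per_seq.
seqlens = jnp.full((max_segments_per_seq,), -1, dtype=jnp.int32)
seqlens = seqlens.at[: len(segment_lengths)].set(jnp.array(segment_lengths, jnp.int32))

# Per-sequence start offsets into the packed token dimension, plus the total, padded with -1.
starts = jnp.cumsum(jnp.array([0] + segment_lengths[:-1], dtype=jnp.int32))
seq_offsets = jnp.full((max_segments_per_seq + 1,), -1, dtype=jnp.int32)
seq_offsets = seq_offsets.at[: len(segment_lengths)].set(starts)
seq_offsets = seq_offsets.at[len(segment_lengths)].set(sum(segment_lengths))

# seqlens     -> [128, 64, 32, -1, -1]
# seq_offsets -> [0, 128, 192, 224, -1, -1]
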