Unverified Commit d74e65f5 authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] Reduce lowering time after cuDNN 90300 (#1032)



* Support actlen = 0 after cuDNN 9.3.0
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add runtime_segment < max_segment tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 87939be1
......@@ -391,7 +391,7 @@ class FusedAttnRunner:
return segment_ids, segment_pad
if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
self.num_segments_per_seq = 3
self.num_segments_per_seq = 2
self.token_q, self.segment_pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
......@@ -461,7 +461,8 @@ class FusedAttnRunner:
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self.num_segments_per_seq,
# +1 for testing runtime_segments < max_segments
"max_segments_per_seq": self.num_segments_per_seq + 1,
}
# Convert the outputs to float32 for the elementwise comparison
......@@ -518,7 +519,7 @@ class FusedAttnRunner:
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self.num_segments_per_seq,
"max_segments_per_seq": self.num_segments_per_seq + 1,
}
# We can compute dBias only for the [1, h, s, s] layout
......
......@@ -30,6 +30,7 @@ from .misc import (
jax_dtype_to_te_dtype,
te_dtype_to_jax_dtype,
get_padded_spec,
get_cudnn_version,
)
from ..sharding import (
all_reduce_sum_along_dp_fsdp,
......@@ -393,12 +394,12 @@ class FusedAttnFwdPrimitive(BasePrimitive):
if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD:
def _fix_len_take(x, condition):
def _fix_len_take(x, condition, fill_value=-1):
x_shape = x.shape
x = x.flatten()
size = x.size
indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
y = jnp.take(x, indices, fill_value=-1)
y = jnp.take(x, indices, fill_value=fill_value)
return jnp.reshape(y, x_shape)
def convert_to_2d(offsets, batch, max_seqlen):
......@@ -425,9 +426,16 @@ class FusedAttnFwdPrimitive(BasePrimitive):
kv_batch = reduce(operator.mul, k.shape[:-3])
# Gather valid q_seqlen, which is greater than 0
# cuDNN version < 9.3.0:
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0)
# cuDNN version >= 9.3.0, which supports act_seqlen = 0
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
if get_cudnn_version() >= (9, 3, 0):
fill_value = 0
else:
fill_value = -1
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)
# Flatten the offset calculation
# max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
......@@ -788,13 +796,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD:
def _fix_len_take(x, condition):
def _fix_len_take(x, condition, fill_value=-1):
x_shape = x.shape
x = x.flatten()
size = x.size
indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
# TODO(rewang): try indices_are_sorted
y = jnp.take(x, indices, fill_value=-1)
y = jnp.take(x, indices, fill_value=fill_value)
return jnp.reshape(y, x_shape)
def convert_to_2d(offsets, batch, max_seqlen):
......@@ -821,9 +829,16 @@ class FusedAttnBwdPrimitive(BasePrimitive):
kv_batch = reduce(operator.mul, k.shape[:-3])
# Gather valid q_seqlen, which is greater than 0
# cuDNN version < 9.3.0:
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0)
# cuDNN version >= 9.3.0, which supports act_seqlen = 0
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
if get_cudnn_version() >= (9, 3, 0):
fill_value = 0
else:
fill_value = -1
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)
# Flatten the offset calculation
# max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
......
......@@ -3,12 +3,16 @@
# See LICENSE for license information.
"""JAX/TE miscellaneous for custom ops"""
import functools
from typing import Tuple
import numpy as np
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import dtype_to_ir_type
from transformer_engine.transformer_engine_jax import DType as TEDType
from transformer_engine import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec
......@@ -128,3 +132,13 @@ def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary):
*shape[transpose_axis_boundary:],
*shape[transpose_start_idx:transpose_axis_boundary],
)
@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
    """Return the runtime cuDNN version as a (major, minor, patch) tuple.

    cuDNN packs its version into a single integer, and the packing changed at
    9.0.0: older releases use major*1000 + minor*100 + patch, while 9.0+ uses
    major*10000 + minor*100 + patch. The result is cached since the runtime
    library version cannot change within a process.
    """
    encoded = transformer_engine_jax.get_cudnn_version()
    # Encoded values below 90000 necessarily come from the pre-9.0 scheme.
    divisor = 10000 if encoded >= 90000 else 1000
    major, remainder = divmod(encoded, divisor)
    minor, patch = divmod(remainder, 100)
    return major, minor, patch
......@@ -139,7 +139,13 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
for (auto num_segments = input_batch; num_segments <= max_num_segments; ++num_segments) {
size_t min_num_segments = input_batch;
auto cudnn_runtime_version = cudnnGetVersion();
if (is_ragged && cudnn_runtime_version >= 90300) {
// For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0
min_num_segments = input_batch * max_segments_per_seq;
}
for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
// the last one is the largest which will be the returned workspace size
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
......@@ -227,6 +233,10 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments
if (is_ragged) {
auto cudnn_runtime_version = cudnnGetVersion();
if (cudnn_runtime_version >= 90300) {
num_segments = input_batch * max_segments_per_seq;
} else {
// workspace can be reused here as it is not used with cuDNN graph at the same time
size_t runtime_num_segments_q =
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
......@@ -235,6 +245,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
}
cudaMemsetAsync(output, 0,
input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream);
}
......@@ -366,7 +377,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
for (auto num_segments = input_batch; num_segments <= max_num_segments; ++num_segments) {
size_t min_num_segments = input_batch;
auto cudnn_runtime_version = cudnnGetVersion();
if (is_ragged && cudnn_runtime_version >= 90300) {
// For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0
min_num_segments = input_batch * max_segments_per_seq;
}
for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
// the last one is the largest which will be the returned workspace size
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
......@@ -460,6 +477,10 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments
if (is_ragged) {
auto cudnn_runtime_version = cudnnGetVersion();
if (cudnn_runtime_version >= 90300) {
num_segments = input_batch * max_segments_per_seq;
} else {
// workspace can be reused here as it is not used with cuDNN graph at the same time
size_t runtime_num_segments_q =
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
......@@ -469,6 +490,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
}
}
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
......
......@@ -59,6 +59,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
m.def("get_fused_attn_backend", &GetFusedAttnBackend);
m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_cudnn_version", &GetCudnnRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes);
......
......@@ -19,6 +19,8 @@ int GetCudaRuntimeVersion() {
return ver;
}
// Returns the cuDNN version detected at runtime, in cuDNN's packed single-integer
// encoding (e.g. 90300 for 9.3.0, matching the `>= 90300` comparisons used in the
// fused-attention paths); exposed to Python via the pybind11 "get_cudnn_version" binding.
size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); }
int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); }
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
......
......@@ -22,6 +22,7 @@ namespace transformer_engine {
namespace jax {
int GetCudaRuntimeVersion();
size_t GetCudnnRuntimeVersion();
int GetDeviceComputeCapability(int gpu_id);
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment