"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "82bc797f17eee5f830edcd058e79390a0c5acff6"
Unverified Commit d74e65f5 authored by Reese Wang, committed by GitHub
Browse files

[JAX] Reduce lowering time after cuDNN 90300 (#1032)



* Support actlen = 0 after cuDNN 9.3.0
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add runtime_segment < max_segment tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 87939be1
...@@ -391,7 +391,7 @@ class FusedAttnRunner: ...@@ -391,7 +391,7 @@ class FusedAttnRunner:
return segment_ids, segment_pad return segment_ids, segment_pad
if get_qkv_format(self.qkv_layout) == QKVFormat.THD: if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
self.num_segments_per_seq = 3 self.num_segments_per_seq = 2
self.token_q, self.segment_pad_q = generate_random_segment_ids( self.token_q, self.segment_pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
) )
...@@ -461,7 +461,8 @@ class FusedAttnRunner: ...@@ -461,7 +461,8 @@ class FusedAttnRunner:
"dropout_probability": self.dropout_prob, "dropout_probability": self.dropout_prob,
"is_training": self.is_training, "is_training": self.is_training,
"qkv_layout": self.qkv_layout, "qkv_layout": self.qkv_layout,
"max_segments_per_seq": self.num_segments_per_seq, # +1 for testing runtime_segments < max_segments
"max_segments_per_seq": self.num_segments_per_seq + 1,
} }
# Convert the outputs to float32 for the elementwise comparison # Convert the outputs to float32 for the elementwise comparison
...@@ -518,7 +519,7 @@ class FusedAttnRunner: ...@@ -518,7 +519,7 @@ class FusedAttnRunner:
"dropout_probability": self.dropout_prob, "dropout_probability": self.dropout_prob,
"is_training": self.is_training, "is_training": self.is_training,
"qkv_layout": self.qkv_layout, "qkv_layout": self.qkv_layout,
"max_segments_per_seq": self.num_segments_per_seq, "max_segments_per_seq": self.num_segments_per_seq + 1,
} }
# We can compute dBias only for the [1, h, s, s] layout # We can compute dBias only for the [1, h, s, s] layout
......
...@@ -30,6 +30,7 @@ from .misc import ( ...@@ -30,6 +30,7 @@ from .misc import (
jax_dtype_to_te_dtype, jax_dtype_to_te_dtype,
te_dtype_to_jax_dtype, te_dtype_to_jax_dtype,
get_padded_spec, get_padded_spec,
get_cudnn_version,
) )
from ..sharding import ( from ..sharding import (
all_reduce_sum_along_dp_fsdp, all_reduce_sum_along_dp_fsdp,
...@@ -393,12 +394,12 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -393,12 +394,12 @@ class FusedAttnFwdPrimitive(BasePrimitive):
if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD: if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD:
def _fix_len_take(x, condition): def _fix_len_take(x, condition, fill_value=-1):
x_shape = x.shape x_shape = x.shape
x = x.flatten() x = x.flatten()
size = x.size size = x.size
indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0] indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
y = jnp.take(x, indices, fill_value=-1) y = jnp.take(x, indices, fill_value=fill_value)
return jnp.reshape(y, x_shape) return jnp.reshape(y, x_shape)
def convert_to_2d(offsets, batch, max_seqlen): def convert_to_2d(offsets, batch, max_seqlen):
...@@ -425,9 +426,16 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -425,9 +426,16 @@ class FusedAttnFwdPrimitive(BasePrimitive):
kv_batch = reduce(operator.mul, k.shape[:-3]) kv_batch = reduce(operator.mul, k.shape[:-3])
# Gather valid q_seqlen, which is greater than 0 # Gather valid q_seqlen, which is greater than 0
# cuDNN version < 9.3.0:
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]] # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0) # cuDNN version >= 9.3.0, which supports act_seqlen = 0
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0) # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
if get_cudnn_version() >= (9, 3, 0):
fill_value = 0
else:
fill_value = -1
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)
# Flatten the offset calculation # Flatten the offset calculation
# max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
...@@ -788,13 +796,13 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -788,13 +796,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD: if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD:
def _fix_len_take(x, condition): def _fix_len_take(x, condition, fill_value=-1):
x_shape = x.shape x_shape = x.shape
x = x.flatten() x = x.flatten()
size = x.size size = x.size
indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0] indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
# TODO(rewang): try indices_are_sorted # TODO(rewang): try indices_are_sorted
y = jnp.take(x, indices, fill_value=-1) y = jnp.take(x, indices, fill_value=fill_value)
return jnp.reshape(y, x_shape) return jnp.reshape(y, x_shape)
def convert_to_2d(offsets, batch, max_seqlen): def convert_to_2d(offsets, batch, max_seqlen):
...@@ -821,9 +829,16 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -821,9 +829,16 @@ class FusedAttnBwdPrimitive(BasePrimitive):
kv_batch = reduce(operator.mul, k.shape[:-3]) kv_batch = reduce(operator.mul, k.shape[:-3])
# Gather valid q_seqlen, which is greater than 0 # Gather valid q_seqlen, which is greater than 0
# cuDNN version < 9.3.0:
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]] # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0) # cuDNN version >= 9.3.0, which supports act_seqlen = 0
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0) # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
if get_cudnn_version() >= (9, 3, 0):
fill_value = 0
else:
fill_value = -1
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)
# Flatten the offset calculation # Flatten the offset calculation
# max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
......
...@@ -3,12 +3,16 @@ ...@@ -3,12 +3,16 @@
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE miscellaneous for custom ops""" """JAX/TE miscellaneous for custom ops"""
import functools
from typing import Tuple
import numpy as np import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import dtype_to_ir_type from jax.interpreters.mlir import dtype_to_ir_type
from transformer_engine.transformer_engine_jax import DType as TEDType from transformer_engine.transformer_engine_jax import DType as TEDType
from transformer_engine import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec from ..sharding import get_padded_spec as te_get_padded_spec
...@@ -128,3 +132,13 @@ def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary): ...@@ -128,3 +132,13 @@ def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary):
*shape[transpose_axis_boundary:], *shape[transpose_axis_boundary:],
*shape[transpose_start_idx:transpose_axis_boundary], *shape[transpose_start_idx:transpose_axis_boundary],
) )
@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
    """Return the runtime cuDNN version as a (major, minor, patch) tuple.

    The extension module reports the version as one packed integer. Values
    below 90000 are decoded with the major number in the thousands place
    (e.g. 8907 -> (8, 9, 7)); all other values are decoded with the major
    number in the ten-thousands place (e.g. 90300 -> (9, 3, 0)). The result
    is memoized, so the extension module is queried only on the first call.
    """
    packed = transformer_engine_jax.get_cudnn_version()
    # The width of the major field depends on which range the packed value is in.
    if packed < 90000:
        major_scale = 1000
    else:
        major_scale = 10000
    major = packed // major_scale
    minor_and_patch = packed % major_scale
    minor = minor_and_patch // 100
    patch = minor_and_patch % 100
    return (major, minor, patch)
...@@ -139,7 +139,13 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -139,7 +139,13 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// It is a WAR to pre-create all possible cuDNN graph at the JIT compile time // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
for (auto num_segments = input_batch; num_segments <= max_num_segments; ++num_segments) { size_t min_num_segments = input_batch;
auto cudnn_runtime_version = cudnnGetVersion();
if (is_ragged && cudnn_runtime_version >= 90300) {
// cuDNN >= 9.3.0 supports act_seqlen = 0, so only the graph for the max number of segments is
// created; cuDNN < 9.3.0 requires running all possible seqlens to address act_seqlen = 0
min_num_segments = input_batch * max_segments_per_seq;
}
for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
// the last one is the largest which will be the returned workspace size // the last one is the largest which will be the returned workspace size
auto q_cu_seqlens_tensor = auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
...@@ -227,14 +233,19 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -227,14 +233,19 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments
if (is_ragged) { if (is_ragged) {
// workspace can be reused here as it is not used with cuDNN graph at the same time auto cudnn_runtime_version = cudnnGetVersion();
size_t runtime_num_segments_q = if (cudnn_runtime_version >= 90300) {
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); num_segments = input_batch * max_segments_per_seq;
size_t runtime_num_segments_kv = } else {
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); // workspace can be reused here as it is not used with cuDNN graph at the same time
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); size_t runtime_num_segments_q =
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
num_segments = runtime_num_segments_q; size_t runtime_num_segments_kv =
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
}
cudaMemsetAsync(output, 0, cudaMemsetAsync(output, 0,
input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream); input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream);
} }
...@@ -366,7 +377,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -366,7 +377,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// It is a WAR to pre-create all possible cuDNN graph at the JIT compile time // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
for (auto num_segments = input_batch; num_segments <= max_num_segments; ++num_segments) { size_t min_num_segments = input_batch;
auto cudnn_runtime_version = cudnnGetVersion();
if (is_ragged && cudnn_runtime_version >= 90300) {
// cuDNN >= 9.3.0 supports act_seqlen = 0, so only the graph for the max number of segments is
// created; cuDNN < 9.3.0 requires running all possible seqlens to address act_seqlen = 0
min_num_segments = input_batch * max_segments_per_seq;
}
for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
// the last one is the largest which will be the returned workspace size // the last one is the largest which will be the returned workspace size
auto q_cu_seqlens_tensor = auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
...@@ -460,14 +477,19 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -460,14 +477,19 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments
if (is_ragged) { if (is_ragged) {
// workspace can be reused here as it is not used with cuDNN graph at the same time auto cudnn_runtime_version = cudnnGetVersion();
size_t runtime_num_segments_q = if (cudnn_runtime_version >= 90300) {
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); num_segments = input_batch * max_segments_per_seq;
size_t runtime_num_segments_kv = } else {
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); // workspace can be reused here as it is not used with cuDNN graph at the same time
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); size_t runtime_num_segments_q =
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
num_segments = runtime_num_segments_q; size_t runtime_num_segments_kv =
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
}
} }
auto q_cu_seqlens_tensor = auto q_cu_seqlens_tensor =
......
...@@ -59,6 +59,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -59,6 +59,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_fused_attn_backend", &GetFusedAttnBackend);
m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_cudnn_version", &GetCudnnRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes); m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes);
......
...@@ -19,6 +19,8 @@ int GetCudaRuntimeVersion() { ...@@ -19,6 +19,8 @@ int GetCudaRuntimeVersion() {
return ver; return ver;
} }
// Returns the value reported by cudnnGetVersion() for the cuDNN library in use.
size_t GetCudnnRuntimeVersion() {
  const size_t version = cudnnGetVersion();
  return version;
}
int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); } int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); }
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed, __global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
......
...@@ -22,6 +22,7 @@ namespace transformer_engine { ...@@ -22,6 +22,7 @@ namespace transformer_engine {
namespace jax { namespace jax {
int GetCudaRuntimeVersion(); int GetCudaRuntimeVersion();
size_t GetCudnnRuntimeVersion();
int GetDeviceComputeCapability(int gpu_id); int GetDeviceComputeCapability(int gpu_id);
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen, void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment