Unverified Commit 65c2798a authored by Charlene Yang, committed by GitHub

[PyTorch] Minor fixes for TE 2.2 (#1589)



* skip cuDNN 9.8 for KV caching
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert from max_seqlen_kv to max_sequence_length for InferenceParams
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename test_paged_attn to test_kv_cache
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove redundant None returns in bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add debug flags when no backend is found
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* skip kv_cache_accuracy tests for cuDNN 9.8
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* truncate length of cu_seqlens for consistency with q/k/v shape
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back padding_brcm for fused attn tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* re-enable kv_cache_accuracy test for 9.8
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cuDNN search dir
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fixes based on review
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove extra empty line
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent abbdd769
@@ -32,6 +32,7 @@ pyTorch
   :members: forward, set_context_parallel_group, set_tensor_parallel_group
 .. autoapiclass:: transformer_engine.pytorch.dot_product_attention.inference.InferenceParams(max_batch_size, max_sequence_length)
+  :members: reset, allocate_memory, pre_step, get_seqlens_pre_step, convert_paged_to_nonpaged, step
 .. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker()
   :members: reset, get_states, set_states, add, fork
......
@@ -38,7 +38,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
-NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py"
+NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases:$FAILED_CASES"
......
@@ -229,7 +229,7 @@ def get_model(
     attn_mask_type = "causal"
     qkv_format = "bshd"
     if mode == "inference":
-        attn_mask_type = "padding_causal" if backend != "FusedAttention" else "padding"
+        attn_mask_type = "padding_causal"
     fp8_recipe = recipe.DelayedScaling(
         margin=0,
@@ -392,9 +392,9 @@ def get_tols(module, backend, dtype):
 @pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"])
 @pytest.mark.parametrize("is_cuda_graph", [False, True])
 @pytest.mark.parametrize("is_fp8", [False, True])
-def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8):
+def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8):
     reset_rng_states()
-    logger = logging.getLogger("test_paged_attn")
+    logger = logging.getLogger("test_kv_cache")
     fp8_recipe = recipe.DelayedScaling(
         margin=0,
         fp8_format=recipe.Format.HYBRID,
@@ -407,7 +407,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda
     fp8_meta["recipe"] = fp8_recipe
     config = model_configs_infer[model]
-    num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1
+    num_layers = 2 if module == "TransformerLayer" else 1
     # flash-attn v2 requires page_size >= 256
     if backend == "FlashAttention" and not fa_utils.v3_is_installed:
         config_max_seqlen_q = config.max_seqlen_q
@@ -437,7 +437,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda
     # initialize inference_params
     inference_params = InferenceParams(
         max_batch_size=max_batch_size,
-        max_seqlen_kv=config.max_seqlen_kv,
+        max_sequence_length=config.max_seqlen_kv,
         num_heads_kv=config.num_gqa_groups,
         head_dim_k=config.head_dim_qk,
         head_dim_v=config.head_dim_v,
......
@@ -2143,7 +2143,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
     inference_params = InferenceParams(
         max_batch_size=B_max,
-        max_seqlen_kv=S_max,
+        max_sequence_length=S_max,
         num_heads_kv=H,
         head_dim_k=head_size,
         dtype=dtype,
......
@@ -99,7 +99,8 @@ target_include_directories(transformer_engine PUBLIC
 # Configure dependencies
 target_link_libraries(transformer_engine PUBLIC
                       CUDA::cublas
-                      CUDA::cudart)
+                      CUDA::cudart
+                      CUDNN::cudnn_all)
 target_include_directories(transformer_engine PRIVATE
                            ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
 target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
......
@@ -5131,6 +5131,16 @@ class FusedAttention(torch.nn.Module):
         # get q_format and kv_format for training and inference
         qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params)
 
+        # cuDNN can work with 0-length sequences in the batch for both bshd/sbhd and thd formats
+        # however, for bshd/sbhd, q/k/v tensors need to have the same batch size as indicated by
+        # cu_seqlens, whereas thd does not have this requirement
+        # e.g. if q_format = bshd, and q.shape = [3, 1, 16, 64], we should have k.shape[0] =
+        # v.shape[0] = q.shape[0], and cu_seqlens_q.shape = cu_seqlens_kv.shape = [4]
+        if q_format in ["bshd", "sbhd"] or kv_format in ["bshd", "sbhd"]:
+            batch_size = query_layer.shape[0] if q_format == "bshd" else query_layer.shape[1]
+            cu_seqlens_q = cu_seqlens_q[: batch_size + 1]
+            cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1]
+
         page_table = None
         if inference_params is None:
             if qkv_format in ["sbhd", "bshd"]:
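To make the effect of this truncation concrete, here is a minimal sketch; the tensor shapes and values are illustrative assumptions, not taken from the test suite:

```python
import torch

# Hypothetical shapes: q is in bshd layout with batch size 3, while cu_seqlens
# was allocated for a larger maximum batch (here 8 sequences -> 9 entries).
q = torch.empty(3, 1, 16, 64)  # [batch, seq, heads, head_dim]
cu_seqlens_q = torch.tensor([0, 1, 2, 3, 3, 3, 3, 3, 3], dtype=torch.int32)

batch_size = q.shape[0]  # 3 for bshd; use q.shape[1] for sbhd
cu_seqlens_q = cu_seqlens_q[: batch_size + 1]  # tensor([0, 1, 2, 3]), shape [4]
```

After the slice, the batch size implied by cu_seqlens (its length minus one) matches q.shape[0], which is what the bshd/sbhd path expects.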
@@ -6214,7 +6224,11 @@ class DotProductAttention(TransformerEngineBaseModule):
         # raise exception if no backend is available
         if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
-            raise ValueError("No dot product attention support for the provided inputs!")
+            raise ValueError(
+                "No dot product attention backend is available for the provided inputs. Please"
+                " run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to find out the reasons for"
+                " disabling all backends."
+            )
 
         # run attention
         if use_flash_attention:
......
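As a rough sketch of how the new error message is meant to be acted on: the flags can be set on the command line (as in the QA script above, e.g. `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 python my_script.py`) or, assuming they are read from the environment at run time as the message implies, from Python before importing Transformer Engine:

```python
import os

# Enable attention debug logging; level 2 reports why each backend was disabled.
os.environ["NVTE_DEBUG"] = "1"
os.environ["NVTE_DEBUG_LEVEL"] = "2"

import transformer_engine.pytorch as te  # noqa: E402  backend selection is now logged
```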
@@ -100,7 +100,7 @@ class InferenceParams:
     ----------
     max_batch_size: int
            Maximum batch size in inference
-    max_seqlen_kv: int
+    max_sequence_length: int
            Maximum sequence length in inference
     num_heads_kv: int
            Number of attention heads in keys and values
@@ -117,7 +117,7 @@ class InferenceParams:
     page_size: int, default = None
            Page size of the KV cache. Required for is_paged = True.
     max_ctx_len: int, default = None
-           Maximum context length in inference. 1 <= max_ctx_len <= max_seqlen_kv.
+           Maximum context length in inference. 1 <= max_ctx_len <= max_sequence_length.
     qkv_format: str, default = "bshd"
            Format of the incoming query/key/value tensors in current iteration
     custom_cache_manager: KVCacheManager, default = None
@@ -127,7 +127,7 @@ class InferenceParams:
     def __init__(
         self,
         max_batch_size: int,
-        max_seqlen_kv: int,
+        max_sequence_length: int,
         num_heads_kv: int = 16,
         head_dim_k: int = 64,
         dtype: torch.dtype = torch.bfloat16,
@@ -140,7 +140,7 @@ class InferenceParams:
         custom_cache_manager: KVCacheManager = None,
     ):
         self.max_batch_size = max_batch_size
-        self.max_seqlen_kv = max_seqlen_kv
+        self.max_sequence_length = max_sequence_length
         self.num_heads_kv = num_heads_kv
         self.head_dim_k = head_dim_k
         self.dtype = dtype
@@ -153,7 +153,7 @@ class InferenceParams:
         )
         self.cache_manager = cache_manager(
             max_batch_size=self.max_batch_size,
-            max_seqlen=self.max_seqlen_kv,
+            max_seqlen=self.max_sequence_length,
             num_heads=self.num_heads_kv,
             head_dim_k=self.head_dim_k,
             dtype=self.dtype,
@@ -163,9 +163,9 @@ class InferenceParams:
             assert page_size is not None, "Paged KV cache requires page_size is not None."
             self.page_size = page_size
             assert (
-                max_seqlen_kv % page_size == 0
-            ), "Paged KV cache requires max_seqlen_kv % page_size = 0."
-            max_pages_per_seq = max_seqlen_kv // page_size
+                max_sequence_length % page_size == 0
+            ), "Paged KV cache requires max_sequence_length % page_size = 0."
+            max_pages_per_seq = max_sequence_length // page_size
             assert (
                 total_num_pages == self.max_batch_size * max_pages_per_seq
             ), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq."
@@ -181,7 +181,7 @@ class InferenceParams:
                 head_dim_k=self.head_dim_k,
                 dtype=self.dtype,
                 max_batch_size=self.max_batch_size,
-                max_seqlen=self.max_seqlen_kv,
+                max_seqlen=self.max_sequence_length,
                 head_dim_v=self.head_dim_v,
             )
@@ -231,7 +231,7 @@ class InferenceParams:
             f"dtype={self.dtype}, "
             f"is_paged={self.is_paged}, "
             f"max_batch_size={self.max_batch_size}, "
-            f"max_seqlen={self.max_seqlen_kv}, "
+            f"max_seqlen={self.max_sequence_length}, "
             f"num_heads={self.num_heads_kv}, "
             f"head_dim_k={self.head_dim_k}, "
             f"head_dim_v={self.head_dim_v}"
@@ -241,8 +241,8 @@ class InferenceParams:
         """
         Allocate memory for the cache. For layer layer_number,
         - NonPagedKVCacheManager:
-          - K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k]
-          - V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v]
+          - K cache: [max_batch_size, max_sequence_length, num_heads_kv, head_dim_k]
+          - V cache: [max_batch_size, max_sequence_length, num_heads_kv, head_dim_v]
         - PagedKVCacheManager:
           - K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k]
           - V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v]
@@ -348,7 +348,7 @@ class InferenceParams:
             Updated cumulative sequence lengths for key and value, [batch_size + 1]
         max_seqlen_q: int
             Update maximum sequence length for query
-        max_seqlen_kv: int
+        max_sequence_length: int
             Update maximum sequence length for key and value
         qkv_format: str
             Updated qkv_format, e.g. 'thd' format becomes 'thd_2bshd' after step()
@@ -373,7 +373,7 @@ class InferenceParams:
             v_cache,
             self.cu_seqlens_q,
             self.cu_seqlens_kv,
-            self.max_seqlen_kv,
+            self.max_sequence_length,
             self.output_qkv_format,
         )
......
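For reference, a hedged sketch of constructing a paged KV cache with the renamed argument. The numbers are made up for illustration and only need to satisfy the assertions above (max_sequence_length divisible by page_size, total_num_pages == max_batch_size * max_pages_per_seq); it also assumes is_paged, page_size, and total_num_pages are accepted as constructor keywords, as the docstring and assertions in the hunks suggest:

```python
import torch
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams

max_batch_size = 4
max_sequence_length = 1024  # renamed from max_seqlen_kv in this commit
page_size = 64
max_pages_per_seq = max_sequence_length // page_size  # 1024 // 64 = 16
total_num_pages = max_batch_size * max_pages_per_seq  # 4 * 16 = 64, as the assertion requires

inference_params = InferenceParams(
    max_batch_size=max_batch_size,
    max_sequence_length=max_sequence_length,
    num_heads_kv=16,
    head_dim_k=64,
    dtype=torch.bfloat16,
    is_paged=True,
    page_size=page_size,
    total_num_pages=total_num_pages,
)
```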