Unverified Commit d7511ec4 authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Inference params (KV cache) support (#466)



Inference params support
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent daa5e184
......@@ -28,6 +28,9 @@ pyTorch
.. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
:members: forward
.. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)
:members: swap_key_value_dict
.. autoapifunction:: transformer_engine.pytorch.fp8_autocast
.. autoapifunction:: transformer_engine.pytorch.checkpoint
......
......@@ -9,6 +9,7 @@ from .module import LayerNormMLP
from .module import LayerNorm
from .module import RMSNorm
from .attention import DotProductAttention
from .attention import InferenceParams
from .attention import MultiheadAttention
from .transformer import TransformerLayer
from .fp8 import fp8_autocast
......
......@@ -70,7 +70,53 @@ else:
_cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv = None, None, None, None
__all__ = ["DotProductAttention", "MultiheadAttention"]
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
class InferenceParams: # pylint: disable=too-few-public-methods
"""
Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
Parameters
----------
max_batch_size : int
maximum batch size during inference.
max_sequence_length : int
maximum sequence length during inference.
"""
def __init__(self, max_batch_size, max_sequence_length):
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
def swap_key_value_dict(self, batch_indices):
"""
Reorders the KV cache using the specified batch indices.
Parameters
----------
batch_indices : List[int]
Sequence of indices to reorder along the batch dimensions of
the KV cache. Must have a length equal to the batch size.
"""
if len(self.key_value_memory_dict) == 0:
raise ValueError("should not swap when dict is empty")
for layer_number, inference_memory in self.key_value_memory_dict.items():
inference_key_memory, inference_value_memory = inference_memory
assert (
len(batch_indices) == inference_key_memory.shape[1]
) # make sure batch size is the same
new_inference_key_memory = inference_key_memory[:, batch_indices]
new_inference_value_memory = inference_value_memory[:, batch_indices]
self.key_value_memory_dict[layer_number] = (
new_inference_key_memory,
new_inference_value_memory,
)
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
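For context, a minimal sketch of how the new class is meant to be used. The KV cache is normally allocated and populated by the attention layers themselves during the forward pass; the layer number, head count, and cache layout below are illustrative assumptions, not values taken from this change.

```python
import torch
from transformer_engine.pytorch import InferenceParams

# Allocate inference parameters for a decode batch of 4 and up to 128 tokens.
inference_params = InferenceParams(max_batch_size=4, max_sequence_length=128)

# The attention layers normally fill the KV cache lazily; one layer is filled
# by hand here only to illustrate the assumed layout:
# (max_sequence_length, max_batch_size, num_heads, head_dim).
num_heads, head_dim = 16, 64  # illustrative sizes
key_cache = torch.zeros(128, 4, num_heads, head_dim)
value_cache = torch.zeros(128, 4, num_heads, head_dim)
inference_params.key_value_memory_dict[1] = (key_cache, value_cache)

# After a beam-search step, reorder the cache along the batch dimension so
# each cache slot follows the beam that survived into it.
inference_params.swap_key_value_dict([2, 0, 3, 1])
```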
......@@ -2522,7 +2568,7 @@ class MultiheadAttention(torch.nn.Module):
attn_mask_type: Optional[str] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None,
inference_params: Optional[InferenceParams] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
......
......@@ -6,13 +6,13 @@
import os
import warnings
from contextlib import nullcontext
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.attention import InferenceParams, MultiheadAttention
from transformer_engine.pytorch.jit import (
set_jit_fusion_options,
warmup_jit_bias_dropout_add_all_dtypes,
......@@ -456,7 +456,7 @@ class TransformerLayer(torch.nn.Module):
enc_dec_attn_mask: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None,
inference_params: Optional[InferenceParams] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
......@@ -512,6 +512,9 @@ class TransformerLayer(torch.nn.Module):
Bias tensor for Q * K.T
fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 or not before use.
inference_params: InferenceParams, default = None
Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
"""
if self_attn_mask_type is None:
......
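A rough sketch of incremental decoding with the new `inference_params` argument on `TransformerLayer`. The sizes, the (sequence, batch, hidden) input layout, and the manual `sequence_len_offset` bookkeeping after each step are assumptions made for illustration and may not match every configuration.

```python
import torch
from transformer_engine.pytorch import InferenceParams, TransformerLayer

# Illustrative sizes only.
hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16
batch_size, prompt_len, new_tokens = 2, 32, 8

layer = TransformerLayer(hidden_size, ffn_hidden_size, num_heads).cuda().eval()
inference_params = InferenceParams(
    max_batch_size=batch_size, max_sequence_length=prompt_len + new_tokens
)

with torch.no_grad():
    # Prefill: run the whole prompt once so the KV cache is populated.
    prompt = torch.randn(prompt_len, batch_size, hidden_size, device="cuda")
    out = layer(prompt, inference_params=inference_params)
    inference_params.sequence_len_offset += prompt_len

    # Decode: feed one token at a time, reusing the cached keys/values.
    for _ in range(new_tokens):
        token = torch.randn(1, batch_size, hidden_size, device="cuda")
        out = layer(token, inference_params=inference_params)
        inference_params.sequence_len_offset += 1
```

If beam search were used, a call to `swap_key_value_dict` would slot into the decode loop right after the beams are reordered.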