Unverified Commit d7511ec4 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Inference params (KV cache) support (#466)



Inference params support
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent daa5e184
...@@ -28,6 +28,9 @@ pyTorch ...@@ -28,6 +28,9 @@ pyTorch
.. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs) .. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
:members: forward :members: forward
.. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)
:members: swap_key_value_dict
.. autoapifunction:: transformer_engine.pytorch.fp8_autocast .. autoapifunction:: transformer_engine.pytorch.fp8_autocast
.. autoapifunction:: transformer_engine.pytorch.checkpoint .. autoapifunction:: transformer_engine.pytorch.checkpoint
......
...@@ -9,6 +9,7 @@ from .module import LayerNormMLP ...@@ -9,6 +9,7 @@ from .module import LayerNormMLP
from .module import LayerNorm from .module import LayerNorm
from .module import RMSNorm from .module import RMSNorm
from .attention import DotProductAttention from .attention import DotProductAttention
from .attention import InferenceParams
from .attention import MultiheadAttention from .attention import MultiheadAttention
from .transformer import TransformerLayer from .transformer import TransformerLayer
from .fp8 import fp8_autocast from .fp8 import fp8_autocast
......
...@@ -70,7 +70,53 @@ else: ...@@ -70,7 +70,53 @@ else:
_cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv = None, None, None, None _cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv = None, None, None, None
__all__ = ["DotProductAttention", "MultiheadAttention"] __all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
class InferenceParams:  # pylint: disable=too-few-public-methods
    """
    Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    Parameters
    ----------
    max_batch_size : int
        maximum batch size during inference.
    max_sequence_length : int
        maximum sequence length during inference.
    """

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Offsets into the cache; presumably advanced by the attention
        # layers as decoding progresses (not modified in this class).
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # Maps layer number -> (key_cache, value_cache). Populated
        # externally; this class only reorders the cached tensors.
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_indices):
        """
        Reorders the KV cache using the specified batch indices.

        Parameters
        ----------
        batch_indices : List[int]
            Sequence of indices to reorder along the batch dimension of
            the KV cache. Must have a length equal to the batch size.

        Raises
        ------
        ValueError
            If the KV cache is empty, or if ``batch_indices`` does not
            match the cached batch size.
        """
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number, inference_memory in self.key_value_memory_dict.items():
            inference_key_memory, inference_value_memory = inference_memory
            # Batch is dimension 1 of the cached tensors. Use an explicit
            # raise instead of `assert` so the check survives `python -O`.
            if len(batch_indices) != inference_key_memory.shape[1]:
                raise ValueError(
                    "batch_indices length must equal the cached batch size "
                    f"({len(batch_indices)} != {inference_key_memory.shape[1]})"
                )
            new_inference_key_memory = inference_key_memory[:, batch_indices]
            new_inference_value_memory = inference_value_memory[:, batch_indices]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory,
                new_inference_value_memory,
            )
def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -2522,7 +2568,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -2522,7 +2568,7 @@ class MultiheadAttention(torch.nn.Module):
attn_mask_type: Optional[str] = None, attn_mask_type: Optional[str] = None,
is_first_microbatch: Optional[bool] = None, is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False, checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None, inference_params: Optional[InferenceParams] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias", core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None, core_attention_bias: Optional[torch.Tensor] = None,
......
...@@ -6,13 +6,13 @@ ...@@ -6,13 +6,13 @@
import os import os
import warnings import warnings
from contextlib import nullcontext from contextlib import nullcontext
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_extensions as tex
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch.attention import InferenceParams, MultiheadAttention
from transformer_engine.pytorch.jit import ( from transformer_engine.pytorch.jit import (
set_jit_fusion_options, set_jit_fusion_options,
warmup_jit_bias_dropout_add_all_dtypes, warmup_jit_bias_dropout_add_all_dtypes,
...@@ -456,7 +456,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -456,7 +456,7 @@ class TransformerLayer(torch.nn.Module):
enc_dec_attn_mask: Optional[torch.Tensor] = None, enc_dec_attn_mask: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None, is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False, checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None, inference_params: Optional[InferenceParams] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias", core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None, core_attention_bias: Optional[torch.Tensor] = None,
...@@ -512,6 +512,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -512,6 +512,9 @@ class TransformerLayer(torch.nn.Module):
Bias tensor for Q * K.T Bias tensor for Q * K.T
fast_zero_fill: bool, default = `True` fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 or not before use. Whether to set output tensors to 0 or not before use.
inference_params: InferenceParams, default = None
Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
""" """
if self_attn_mask_type is None: if self_attn_mask_type is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment