Unverified commit b459ccc9 authored by Oleg Goncharov, committed by GitHub

[PyTorch] Adjusted the logic of MHA and DPA to enable speculative decoding (#668)



* Modified MHA and DPA logic to use causal softmax and FA for inference
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Adjusted unfused attention and softmax logic for inference
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Cleaned up the code per pylint
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Added test cases to evaluate numerics of incremental decoding
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Apply suggestions from code review [sequence start-end]
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Apply suggestions from code review [inference_params offset update]
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Fixed bug in KV-cache indices and updated test suite
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Added inference_params description and applied suggestions from the code review
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Adjusted absolute tolerances in numerics tests
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Cleaned up the files per pylint
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 728e335f
......@@ -22,7 +22,7 @@ from transformer_engine.pytorch.utils import (
)
from transformer_engine.pytorch import (
DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm
MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
......@@ -1397,3 +1397,118 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
y_bshd = block_bshd(x_bshd)
assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()])
model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
}
backends_inference = ["FlashAttention", "UnfusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys())
@pytest.mark.parametrize("use_RoPE", all_boolean)
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_inference[model_key]
S = config.seq_len
B = bs
H = config.num_attention_heads
D = config.hidden_size
head_size = config.embed
layer_number = 1
# Limits the max size of KV-cache
B_max = B
S_max = S + 2
if module == "TransformerLayer":
model = (
TransformerLayer(
hidden_size=D,
ffn_hidden_size= 4 * D,
num_attention_heads=H,
attn_input_format=input_format,
layer_number=layer_number,
attention_dropout = 0.0
)
.to(dtype=dtype)
.cuda()
.eval()
)
else:
model = (
MultiheadAttention(
hidden_size=D,
num_attention_heads=H,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout = 0.0
)
.to(dtype=dtype)
.cuda()
.eval()
)
inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
input = torch.randn((S, B, D), dtype=dtype, device="cuda")
if input_format == "bshd":
input = input.transpose(0, 1).contiguous()
incremental_output = torch.zeros_like(input)
# Generate output for the entire sequence
full_output = model(
hidden_states=input,
rotary_pos_emb=rotary_freqs if use_RoPE else None)
# Incrementally generate outputs using the KV-cache
for i in range(S):
if input_format == "sbhd":
incremental_input = input[i].view(1,B,D)
else:
incremental_input = input[:, i, :].view(B,1,D)
line_output = model(
hidden_states=incremental_input,
inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None)
inference_params.sequence_len_offset += 1
if input_format == "sbhd":
incremental_output[i] = line_output.view(B,D)
else:
incremental_output[:, i, :] = line_output.view(B,D)
if module == "TransformerLayer":
atol = {
torch.float32 : 5e-3,
torch.half : 5e-3,
torch.bfloat16: 5e-2,
}
else:
atol = {
torch.float32 : 1e-3,
torch.half : 1e-3,
torch.bfloat16: 1e-2,
}
# Check if the fully generated output matches the one generated incrementally
assert_allclose(full_output, incremental_output, atol[dtype])
......@@ -84,7 +84,6 @@ _alibi_cache = {
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
class InferenceParams: # pylint: disable=too-few-public-methods
"""
Inference parameters that are passed to the main model in order
......@@ -1180,7 +1179,7 @@ def apply_rotary_pos_emb(
Parameters
----------
t: torch.Tensor
Input tensor of shape `[s, b, h, d]`, `[s, b, h, d]` or `[t, h, d]`, on which
Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
rotary positional embedding will be applied.
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
......@@ -2523,6 +2522,7 @@ class DotProductAttention(torch.nn.Module):
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> torch.Tensor:
"""
Dot Product Attention Layer.
......@@ -2616,6 +2616,16 @@ class DotProductAttention(torch.nn.Module):
to the attention score of query i and key j.
fast_zero_fill: bool, default = `True`
Whether to use the fast path to set output tensors to 0 or not.
inference_params: Optional[InferenceParams], default = `None`
Optimizes execution performance during inference by caching the keys and values of the
current decoding iteration. These cached values are appended to the keys and values
computed in previous iterations, eliminating the need to recompute them for the
entire sequence.
`inference_params` must be initialized before use so that sufficient memory is
allocated for the cache.
`sequence_len_offset` should be advanced only after a complete forward pass.
If rotary positional embeddings (RoPE) are used, they must be prepared beforehand.
Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being the more efficient.
"""
assert (
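For context, a minimal usage sketch of the pattern this docstring describes, condensed from the test added in this commit (sizes and dtype are illustrative only):

import torch
from transformer_engine.pytorch import MultiheadAttention, InferenceParams

B, S, D, H = 2, 16, 768, 12  # illustrative sizes
mha = MultiheadAttention(
    hidden_size=D, num_attention_heads=H, layer_number=1, attention_dropout=0.0,
).to(dtype=torch.bfloat16).cuda().eval()

# Pre-allocate the KV-cache before decoding starts.
inference_params = InferenceParams(max_batch_size=B, max_sequence_length=S)

tokens = torch.randn((S, B, D), dtype=torch.bfloat16, device="cuda")
for i in range(S):
    # One decoding step: keys/values from earlier steps come from the cache.
    out = mha(hidden_states=tokens[i].view(1, B, D),
              inference_params=inference_params)
    # Advance the cache offset only after the complete forward pass.
    inference_params.sequence_len_offset += 1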
......@@ -2643,6 +2653,39 @@ class DotProductAttention(torch.nn.Module):
if qkv_format is None:
qkv_format = self.qkv_format
if inference_params is not None:
assert self.layer_number is not None, "Layer number must be set!"
if qkv_format == "bshd":
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)
(inference_key_memory, inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy keys and values into KV-cache
inference_key_memory[
sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
inference_value_memory[
sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
if qkv_format == "bshd":
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()
assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
and value_layer.shape[-2] == self.num_gqa_groups_per_partition
), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
......@@ -2721,12 +2764,15 @@ class DotProductAttention(torch.nn.Module):
use_flash_attention = False
# Filter: cross attention + causal mask.
if (_flash_attn_2_1_plus
# (in training mode)
if (inference_params is None
and _flash_attn_2_1_plus
and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv):
and max_seqlen_q != max_seqlen_kv
):
warnings.warn(
"Disabling the use of FlashAttention since version 2.1+ has changed its behavior "
"for causal mask in cross attention. See "
"In training mode, disable the use of FlashAttention since version 2.1+ has "
"changed its behavior for causal mask in cross attention. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
......@@ -2753,7 +2799,11 @@ class DotProductAttention(torch.nn.Module):
if attn_mask_type == "arbitrary":
use_flash_attention = False
use_fused_attention = False
if "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
if (inference_params is None
and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv
):
use_unfused_attention = False
# Filter: bias.
......@@ -3446,12 +3496,12 @@ class MultiheadAttention(torch.nn.Module):
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
# =================================================
# Pre-allocate memory for key-values for inference.
# Pre-allocate memory for key-values for inference
# =================================================
if inference_params and self.layer_number is not None:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_seq_len = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
......@@ -3469,9 +3519,9 @@ class MultiheadAttention(torch.nn.Module):
inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]
# =====================
# ======================
# Query, Key, and Value
# =====================
# ======================
if self.attention_type == "self":
# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
......@@ -3593,51 +3643,37 @@ class MultiheadAttention(torch.nn.Module):
)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
# ==================================
# ======================================================
# Apply relative positional encoding (rotary embedding)
# ======================================================
# duplicate the pos_emb for self attention
if rotary_pos_emb is not None:
# duplicate the pos_emb for self attention
if not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = ((rotary_pos_emb,) * 2)
if inference_params and self.layer_number is not None:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = key_layer
inference_value_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = value_layer
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...
]
# adjust the key rotary positional embedding
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
rotary_pos_emb = (q_pos_emb, k_pos_emb)
# ==================================
# core attention computation
# ==================================
# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
# adjust key and value for inference
if inference_params is not None:
if self.qkv_format == "sbhd":
sequence_length = key_layer.size(0)
elif self.qkv_format == "bshd":
sequence_length = key_layer.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + sequence_length
q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
# ===========================
# Core attention computation
# ===========================
context_layer = self.core_attention(
query_layer,
key_layer,
......@@ -3653,11 +3689,12 @@ class MultiheadAttention(torch.nn.Module):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=fast_zero_fill,
inference_params=inference_params,
)
# =================
# ===================
# Output. [sq, b, h]
# =================
# ===================
projection_output = self.proj(
context_layer, is_first_microbatch=is_first_microbatch
......
......@@ -20,11 +20,18 @@ THREADS_PER_BLOCK = 128
_default_causal_mask = {}
def _get_default_causal_mask(sq: int) -> torch.Tensor:
def _get_default_causal_mask(sq: int, sk: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input"""
if sq not in _default_causal_mask:
_default_causal_mask[sq] = torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
return _default_causal_mask[sq]
if sq == 1:
return torch.zeros((1, sk), dtype=torch.bool, device="cuda")
matrix_shape = (sq, sk)
if matrix_shape not in _default_causal_mask:
diagonal_offset = sk - sq + 1
_default_causal_mask[matrix_shape] = torch.triu(
torch.ones(sq, sk, dtype=torch.bool, device="cuda"),
diagonal=diagonal_offset)
return _default_causal_mask[matrix_shape]
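To see why the diagonal offset is sk - sq + 1: with a KV-cache the sq new queries occupy the last sq positions of a length-sk key sequence, so query i may attend to keys 0 .. (sk - sq + i). A small CPU-only sketch with illustrative shapes:

import torch

sq, sk = 3, 5  # 3 new queries against 5 cached keys
mask = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=sk - sq + 1)
print(mask.int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)
# True (1) entries are masked out; row i attends to keys 0 .. sk - sq + i.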
def _get_onnx_export_causal_mask(
......@@ -334,47 +341,46 @@ class FusedScaleMaskSoftmax(nn.Module):
attn_batches = b * np
if ( # pylint: disable=too-many-boolean-expressions
self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_float16 # input must be fp16
and 16 <= sk <= 16384 # sk must be 16 ~ 16384
and sk % 8 == 0 # sk must be divisor of 8
and sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
not self.scaled_masked_softmax_fusion # user doesn't want to fuse
or not self.input_in_float16 # input must be fp16
or sk < 16
or sk > 16384 # sk must be 16 ~ 16384
or sk % 8 != 0 # sk must be divisor of 8
or self.attn_mask_type == "arbitrary" # Custom masks not supported
):
return False
if self.attn_mask_type == "causal": # unfused causal softmax kernel
return True
if (sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
and self.attn_mask_type != "arbitrary" # Custom masks not supported
):
if 0 <= sk <= 16384:
batch_per_block = self.get_batch_per_block(int(sk))
if self.attn_mask_type == "causal":
if attn_batches % batch_per_block == 0:
return True
elif self.attn_mask_type == "padding":
if (
mask is not None
and sq % batch_per_block == 0
and mask.shape[-2] == sq
and mask.shape[-1] == sk
):
return True
else:
if sq % batch_per_block == 0:
return True
batch_per_block = self.get_batch_per_block(int(sk))
if self.attn_mask_type == "padding":
if (
mask is not None
and sq % batch_per_block == 0
and mask.shape[-2] == sq
and mask.shape[-1] == sk
):
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(
self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
) -> torch.Tensor:
"""Fused masked softmax kernel"""
b, np, sq, sk = inp.size()
scale = 1.0 if scale is None else scale
if self.attn_mask_type == "causal":
assert sq == sk, "causal mask is only for self attention"
return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale)
# input is 3D tensor (attn_batches, sq, sk)
inp = inp.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale)
return probs.view(b, np, sq, sk)
# input is 4D tensor (b, np, sq, sk)
if mask is not None and self.attn_mask_type != "no_mask":
return ScaledMaskedSoftmax.apply(inp, mask, scale)
......@@ -391,12 +397,12 @@ class FusedScaleMaskSoftmax(nn.Module):
inp = inp * scale
if self.attn_mask_type == "causal":
seq_len_q, seq_len_k = inp.size(2), inp.size(3)
if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
seq_len_q, seq_len_k = inp.size(2), inp.size(3)
assert self.kvcache_max_seq >= seq_len_k
mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask)
else:
mask = _get_default_causal_mask(inp.size(2))
mask = _get_default_causal_mask(seq_len_q, seq_len_k)
mask_output = inp
if mask is not None and self.attn_mask_type != "no_mask":
......