Unverified Commit 37b3b7a7 authored by Fabian Joswig, committed by GitHub

[PyTorch] Fix bug in micro batched inference with rotary embeddings (#536)



[fix] fixed micro batched inference with RoPE
Signed-off-by: Fabian Joswig <fabian.joswig@deepl.com>
Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent e0be70d6
@@ -2837,7 +2837,6 @@ class MultiheadAttention(torch.nn.Module):
         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
-        is_first_step = False
         if inference_params and self.layer_number is not None:
             if self.layer_number not in inference_params.key_value_memory_dict:
                 inf_max_seq_len = inference_params.max_sequence_len
@@ -2852,7 +2851,6 @@ class MultiheadAttention(torch.nn.Module):
                     inference_key_memory,
                     inference_value_memory,
                 )
-                is_first_step = True
             else:
                 (
                     inference_key_memory,
@@ -3014,20 +3012,7 @@ class MultiheadAttention(torch.nn.Module):
             # adjust the key rotary positional embedding
             if rotary_pos_emb is not None:
                 q_pos_emb, k_pos_emb = rotary_pos_emb
-                # need to cross check this condition during inference
-                # if not set_inference_key_value_memory:
-                if not is_first_step:
-                    # In inference, we compute one token at a time.
-                    # Select the correct positional embedding
-                    # (only the last token in the sequence)
-                    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
-                else:
-                    # In the first forward pass of inference,
-                    # we use the entire provided prefix.
-                    # q_pos_emb here has the rope embeddings of the entire
-                    # prefix + to-be-generated output so
-                    # we slice to just the prefix.
-                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
+                q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
                 k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                 rotary_pos_emb = (q_pos_emb, k_pos_emb)
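
The slice change in the last hunk is the substance of the fix: with micro-batched inference a forward pass can append several tokens at once, so the query rotary embeddings must cover positions sequence_start through sequence_end rather than only the final position. Below is a minimal sketch of that indexing in a toy decode loop; the `rope_table` and the loop are hypothetical illustrations, not Transformer Engine code, and only contrast the old and new slicing.

```python
# Minimal sketch (hypothetical names, not Transformer Engine's API) of why the
# new slice matters: a decode step may feed a chunk of several tokens, so the
# query rotary embeddings must cover positions [sequence_start, sequence_end),
# not only the single last position assumed by the old `sequence_end - 1` slice.
import torch

max_seq_len, head_dim = 16, 8
# Hypothetical precomputed rotary table, shaped [seq, 1, 1, dim] like the
# q_pos_emb / k_pos_emb tensors in the diff above.
rope_table = torch.randn(max_seq_len, 1, 1, head_dim)

offset = 0  # tokens already written to the KV cache for this layer
for chunk_len in (4, 2, 3):  # a 4-token prefix, then micro batches of 2 and 3 tokens
    sequence_start = offset
    sequence_end = offset + chunk_len

    # New behaviour: embeddings for every token in the current chunk.
    q_pos_emb = rope_table[sequence_start:sequence_end]  # [chunk_len, 1, 1, dim]

    # Old behaviour after the first step: only the last position,
    # which is wrong as soon as chunk_len > 1.
    q_pos_emb_old = rope_table[sequence_end - 1 : sequence_end]  # [1, 1, 1, dim]

    # Keys cover all cached positions [0, sequence_end), matching the KV cache.
    k_pos_emb = rope_table[:sequence_end]

    assert q_pos_emb.shape[0] == chunk_len
    offset = sequence_end
```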