"vscode:/vscode.git/clone" did not exist on "033b77ebc475ea1064ae1e77f1fee326c0b1332d"
Commit de2949f3 authored by Tri Dao

[Rotary] Pass max_seqlen from mha.py to rotary during inference

parent 942fcbf0
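
For context, a rough sketch of the motivation (not part of the commit): during step-by-step decoding with a KV cache, the rotary module only sees the newly generated token (seqlen == 1), so sizing the cos/sin cache from seqlen + seqlen_offset forces a rebuild at every step as the offset grows, whereas sizing it from the known maximum sequence length builds it once. A minimal illustration, assuming a cache that is only rebuilt when it is too short (as _update_cos_sin_cache does):

    def cache_len_per_step(max_seqlen=None, num_steps=4):
        # Track how large the cos/sin cache is after each decoding step.
        cached = 0
        sizes = []
        for seqlen_offset in range(num_steps):  # one new token per step
            seqlen = 1
            target = max_seqlen if max_seqlen is not None else seqlen + seqlen_offset
            if target > cached:  # rebuild only when the current cache is too short
                cached = target
            sizes.append(cached)
        return sizes

    print(cache_len_per_step())              # [1, 2, 3, 4] -> cache rebuilt at every step
    print(cache_len_per_step(max_seqlen=4))  # [4, 4, 4, 4] -> cache built once up front
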
@@ -402,10 +402,10 @@ class RotaryEmbedding(torch.nn.Module):
         Apply rotary embedding *inplace* to qkv and / or kv.
         """
         seqlen = qkv.shape[1]
-        if isinstance(seqlen_offset, int):
-            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
-        elif max_seqlen is not None:
+        if max_seqlen is not None:
             self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+        elif isinstance(seqlen_offset, int):
+            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
         if kv is None:
             if self.scale is None:
                 return apply_rotary_emb_qkv_(
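
A minimal standalone sketch of the reordered dispatch above (simplified: no scaling, no interleaving, and the rotation itself is omitted), showing that an explicit max_seqlen now takes precedence over seqlen + seqlen_offset when sizing the cos/sin cache. TinyRotaryCache is a toy stand-in, not the library class:

    import torch

    class TinyRotaryCache:
        """Toy stand-in for RotaryEmbedding's cos/sin cache (illustration only)."""

        def __init__(self, dim, base=10000.0):
            self.dim = dim
            self.base = base
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None

        def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
            if seqlen > self._seq_len_cached:  # rebuild only when the cache is too short
                self._seq_len_cached = seqlen
                inv_freq = 1.0 / (
                    self.base ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
                )
                t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
                freqs = torch.outer(t, inv_freq)
                self._cos_cached = freqs.cos().to(dtype)
                self._sin_cached = freqs.sin().to(dtype)

        def forward(self, qkv, seqlen_offset=0, max_seqlen=None):
            seqlen = qkv.shape[1]
            # Same ordering as the hunk above: an explicit max_seqlen wins,
            # otherwise fall back to the running length seqlen + seqlen_offset.
            if max_seqlen is not None:
                self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
            elif isinstance(seqlen_offset, int):
                self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
            return qkv  # the rotation itself is omitted in this sketch

    rot = TinyRotaryCache(dim=64)
    step = torch.zeros(2, 1, 3, 8, 64)  # (batch, seqlen=1, three, heads, head_dim): one decoded token
    rot.forward(step, seqlen_offset=0, max_seqlen=2048)
    print(rot._seq_len_cached)  # 2048: the cache is sized once for the whole generation
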
@@ -606,6 +606,9 @@ class MHA(nn.Module):
                 else {"key_padding_mask": key_padding_mask, **kwargs}
             )
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else None
+        )
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
             if not self.return_residual:
@@ -623,7 +626,9 @@ class MHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+                    qkv = self.rotary_emb(
+                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_attn(qkv, **kwargs)
@@ -669,7 +674,9 @@ class MHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
+                    q, kv = self.rotary_emb(
+                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_cross_attn(q, kv, **kwargs)
@@ -851,6 +858,9 @@ class ParallelMHA(nn.Module):
         if seqlen is not None:
             qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else None
+        )
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
             if (
@@ -859,7 +869,9 @@ class ParallelMHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+                    qkv = self.rotary_emb(
+                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_attn(qkv, **kwargs)
@@ -889,7 +901,9 @@ class ParallelMHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
+                    q, kv = self.rotary_emb(
+                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_cross_attn(q, kv, **kwargs)
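
On the caller side, a hedged sketch of how the forwarded maximum length would be exercised during generation. The decode loop, the greedy_decode helper, and the locally defined InferenceParams (mirroring only the fields this diff touches) are assumptions for illustration, not the library's generation utilities; the attention layer is assumed to be an MHA-like module set up for KV-cache inference (rotary enabled, layer_idx assigned):

    from dataclasses import dataclass, field

    @dataclass
    class InferenceParams:
        # Locally defined stand-in mirroring the fields the diff reads:
        # max_sequence_len feeds rotary_max_seqlen, sequence_len_offset feeds seqlen_offset.
        max_sequence_len: int
        max_batch_size: int
        sequence_len_offset: int = 0
        key_value_memory_dict: dict = field(default_factory=dict)

    def greedy_decode(mha_layer, prompt, num_new_tokens):
        """prompt: (batch, prompt_len, embed_dim); mha_layer: an MHA-like module as in the diff."""
        batch_size, prompt_len, _ = prompt.shape
        params = InferenceParams(
            max_sequence_len=prompt_len + num_new_tokens,  # source of rotary_max_seqlen
            max_batch_size=batch_size,
        )
        out = mha_layer(prompt, inference_params=params)   # rotary cache sized for the full length
        params.sequence_len_offset += prompt_len
        for _ in range(num_new_tokens):
            nxt = out[:, -1:, :]                            # placeholder for the next-token embedding
            out = mha_layer(nxt, inference_params=params)   # seqlen_offset advances; cache is reused
            params.sequence_len_offset += 1
        return out
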