Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7d312ad2
Unverified
Commit
7d312ad2
authored
Feb 20, 2024
by
Joao Gante
Committed by
GitHub
Feb 20, 2024
Browse files
Llama: fix batched generation (#29109)
parent
ff76e7c2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
8 deletions
+35
-8
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+30
-3
tests/test_cache_utils.py
tests/test_cache_utils.py
+5
-5
No files found.
src/transformers/models/llama/modeling_llama.py
View file @
7d312ad2
...
...
@@ -101,11 +101,34 @@ class LlamaRotaryEmbedding(nn.Module):
inv_freq
=
1.0
/
(
self
.
base
**
(
torch
.
arange
(
0
,
self
.
dim
,
2
,
dtype
=
torch
.
int64
).
float
().
to
(
device
)
/
self
.
dim
))
self
.
register_buffer
(
"inv_freq"
,
inv_freq
,
persistent
=
False
)
@
property
def
sin_cached
(
self
):
logger
.
warning_once
(
"The sin_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead."
)
return
self
.
_sin_cached
@
property
def
cos_cached
(
self
):
logger
.
warning_once
(
"The cos_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead."
)
return
self
.
_cos_cached
def
forward
(
self
,
x
,
position_ids
,
seq_len
=
None
):
# x: [bs, num_attention_heads, seq_len, head_size]
freqs
=
(
self
.
inv_freq
[:,
None
].
float
().
expand
(
-
1
,
position_ids
.
shape
[
0
])
@
(
position_ids
.
float
())).
t
()
inv_freq_expanded
=
self
.
inv_freq
[
None
,
:,
None
].
float
().
expand
(
position_ids
.
shape
[
0
],
-
1
,
1
)
position_ids_expanded
=
position_ids
[:,
None
,
:].
float
()
freqs
=
(
inv_freq_expanded
@
position_ids_expanded
).
transpose
(
1
,
2
)
emb
=
torch
.
cat
((
freqs
,
freqs
),
dim
=-
1
)
return
emb
.
cos
().
to
(
dtype
=
x
.
dtype
),
emb
.
sin
().
to
(
dtype
=
x
.
dtype
)
cos
=
emb
.
cos
().
to
(
dtype
=
x
.
dtype
)
sin
=
emb
.
sin
().
to
(
dtype
=
x
.
dtype
)
# backwards compatibility
self
.
_cos_cached
=
cos
self
.
_sin_cached
=
sin
return
cos
,
sin
class
LlamaLinearScalingRotaryEmbedding
(
LlamaRotaryEmbedding
):
...
...
@@ -181,6 +204,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos
=
cos
.
unsqueeze
(
unsqueeze_dim
)
sin
=
sin
.
unsqueeze
(
unsqueeze_dim
)
q_embed
=
(
q
*
cos
)
+
(
rotate_half
(
q
)
*
sin
)
k_embed
=
(
k
*
cos
)
+
(
rotate_half
(
k
)
*
sin
)
return
q_embed
,
k_embed
...
...
@@ -1033,6 +1058,7 @@ class LlamaModel(LlamaPreTrainedModel):
batch_size
,
seq_length
=
input_tensor
.
shape
[:
2
]
dtype
=
input_tensor
.
dtype
device
=
input_tensor
.
device
# support going beyond cached `max_position_embedding`
if
seq_length
>
self
.
causal_mask
.
shape
[
-
1
]:
...
...
@@ -1048,8 +1074,9 @@ class LlamaModel(LlamaPreTrainedModel):
(
self
.
config
.
max_position_embeddings
,
self
.
config
.
max_position_embeddings
),
fill_value
=
torch
.
finfo
(
dtype
).
min
,
)
causal_mask
=
torch
.
triu
(
mask
,
diagonal
=
1
)
.
to
(
dtype
)
causal_mask
=
torch
.
triu
(
mask
,
diagonal
=
1
)
causal_mask
=
causal_mask
.
to
(
dtype
=
dtype
,
device
=
device
)
if
attention_mask
is
not
None
and
attention_mask
.
dim
()
==
2
:
mask_length
=
attention_mask
.
shape
[
-
1
]
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
...
...
tests/test_cache_utils.py
View file @
7d312ad2
...
...
@@ -293,7 +293,7 @@ class CacheIntegrationTest(unittest.TestCase):
@
parameterized
.
expand
([
"eager"
,
"sdpa"
,
"flash_attention_2"
])
def
test_static_cache_greedy_sampling_pad_left
(
self
,
attn_implementation
):
EXPECTED_GENERATION
=
[
"The best color is the one that complements the s
ubject you are photograph
"
,
"The best color is the one that complements the s
kin tone of the
"
,
"We should not undermind the issues at hand.
\n
We should not undermind the issues"
,
]
...
...
@@ -333,18 +333,18 @@ class CacheIntegrationTest(unittest.TestCase):
@
parameterized
.
expand
([
"eager"
,
"sdpa"
,
"flash_attention_2"
])
def
test_static_cache_greedy_sampling_pad_right
(
self
,
attn_implementation
):
EXPECTED_GENERATION
=
[
"The best color is
\n\n\n\n\n\n\n\n\n\n
"
,
"We should not undermind the issues at hand
, but address them head on.
\n
I think
"
,
"The best color is
Ћ the one that complements the skin tone of
"
,
"We should not undermind the issues at hand
.
\n
We should not undermind the issues
"
,
]
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"NousResearch/Llama-2-7b-chat-hf"
,
padding_side
=
"
lef
t"
,
pad_token
=
"<s>"
"NousResearch/Llama-2-7b-chat-hf"
,
padding_side
=
"
righ
t"
,
pad_token
=
"<s>"
)
model
=
AutoModelForCausalLM
.
from_pretrained
(
"NousResearch/Llama-2-7b-chat-hf"
,
torch_dtype
=
torch
.
bfloat16
,
attn_implementation
=
attn_implementation
,
).
to
(
"cuda:1"
)
).
to
(
torch_device
)
inputs
=
tokenizer
(
[
"The best color is"
,
"We should not undermind the issues at hand"
],
padding
=
True
,
return_tensors
=
"pt"
).
to
(
model
.
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment