Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ae49b218
Unverified
Commit
ae49b218
authored
Feb 21, 2024
by
Younes Belkada
Committed by
GitHub
Feb 21, 2024
Browse files
FIX [`Gemma`] Fix bad rebase with transformers main (#29170)
fix bad rebase
parent
594c1277
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
23 deletions
+21
-23
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/gemma/modeling_gemma.py
+21
-23
No files found.
src/transformers/models/gemma/modeling_gemma.py
View file @
ae49b218
...
...
@@ -124,7 +124,7 @@ def rotate_half(x):
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def
apply_rotary_pos_emb
(
q
,
k
,
cos
,
sin
,
position_ids
,
unsqueeze_dim
=
1
):
def
apply_rotary_pos_emb
(
q
,
k
,
cos
,
sin
,
position_ids
=
None
,
unsqueeze_dim
=
1
):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
...
...
@@ -132,9 +132,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
...
...
@@ -940,6 +939,10 @@ class GemmaModel(GemmaPreTrainedModel):
attentions
=
all_self_attns
,
)
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def
_update_causal_mask
(
self
,
attention_mask
,
input_tensor
):
if
self
.
config
.
_attn_implementation
==
"flash_attention_2"
:
if
attention_mask
is
not
None
and
0.0
in
attention_mask
:
...
...
@@ -955,16 +958,8 @@ class GemmaModel(GemmaPreTrainedModel):
causal_mask
=
torch
.
full
((
2
*
self
.
causal_mask
.
shape
[
-
1
],
2
*
self
.
causal_mask
.
shape
[
-
1
]),
fill_value
=
1
)
self
.
register_buffer
(
"causal_mask"
,
torch
.
triu
(
causal_mask
,
diagonal
=
1
),
persistent
=
False
)
if
hasattr
(
self
,
"causal_mask"
):
# we use the current dtype to avoid any overflows
causal_mask
=
(
self
.
causal_mask
[
None
,
None
,
:,
:].
repeat
(
batch_size
,
1
,
1
,
1
).
to
(
dtype
)
*
torch
.
finfo
(
dtype
).
min
)
else
:
mask
=
torch
.
full
(
(
self
.
config
.
max_position_embeddings
,
self
.
config
.
max_position_embeddings
),
fill_value
=
torch
.
finfo
(
dtype
).
min
,
)
causal_mask
=
torch
.
triu
(
mask
,
diagonal
=
1
)
# We use the current dtype to avoid any overflows
causal_mask
=
self
.
causal_mask
[
None
,
None
,
:,
:].
repeat
(
batch_size
,
1
,
1
,
1
).
to
(
dtype
)
*
torch
.
finfo
(
dtype
).
min
causal_mask
=
causal_mask
.
to
(
dtype
=
dtype
,
device
=
device
)
if
attention_mask
is
not
None
and
attention_mask
.
dim
()
==
2
:
...
...
@@ -1146,29 +1141,32 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
if
past_key_values
:
position_ids
=
position_ids
[:,
-
input_ids
.
shape
[
1
]
:]
if
past_key_value
:
=
getattr
(
self
.
model
.
layers
[
0
].
self_attn
,
"past_key_value"
,
None
):
if
getattr
(
self
.
model
.
layers
[
0
].
self_attn
,
"past_key_value"
,
None
)
is
not
None
:
# generation with static cache
past_length
=
past_key_value
.
get_seq_length
()
cache_position
=
kwargs
.
get
(
"cache_position"
,
None
)
if
cache_position
is
None
:
past_length
=
0
else
:
past_length
=
cache_position
[
-
1
]
+
1
input_ids
=
input_ids
[:,
past_length
:]
position_ids
=
position_ids
[:,
past_length
:]
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position
=
kwargs
.
get
(
"cache_position"
,
None
)
if
cache_position
is
None
:
cache_position
=
torch
.
arange
(
past_length
,
past_length
+
position_ids
.
shape
[
-
1
],
device
=
position_ids
.
device
)
cache_position
=
torch
.
arange
(
past_length
,
past_length
+
position_ids
.
shape
[
-
1
],
device
=
position_ids
.
device
)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if
inputs_embeds
is
not
None
and
past_key_values
is
None
:
model_inputs
=
{
"inputs_embeds"
:
inputs_embeds
}
else
:
model_inputs
=
{
"input_ids"
:
input_ids
}
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs
=
{
"input_ids"
:
input_ids
.
contiguous
()}
model_inputs
.
update
(
{
"position_ids"
:
position_ids
,
"position_ids"
:
position_ids
.
contiguous
()
,
"cache_position"
:
cache_position
,
"past_key_values"
:
past_key_values
,
"use_cache"
:
kwargs
.
get
(
"use_cache"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment