chenpangpang / transformers · Commit ae49b218 (unverified)

FIX [`Gemma`] Fix bad rebase with transformers main (#29170)

fix bad rebase

Authored Feb 21, 2024 by Younes Belkada; committed by GitHub on Feb 21, 2024.
Parent: 594c1277
Showing 1 changed file with 21 additions and 23 deletions.

src/transformers/models/gemma/modeling_gemma.py (+21, -23)
@@ -124,7 +124,7 @@ def rotate_half(x):
 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
@@ -132,9 +132,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
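The two hunks above only change the signature and docstring. For orientation, here is a minimal sketch of how a body consistent with the new signature looks, modeled on the Llama implementation this function is copied from; the shape comments are assumptions about the typical call, and `position_ids` is kept only for backward compatibility since `cos`/`sin` already arrive selected per position.

import torch

def rotate_half(x):
    # Rotates half the hidden dims of the input.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # Assumed shapes: q, k are (batch, num_heads, seq_len, head_dim); cos, sin are
    # (batch, seq_len, head_dim). unsqueeze_dim=1 inserts the missing head axis so
    # cos/sin broadcast against q and k.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed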
@@ -940,6 +939,10 @@ class GemmaModel(GemmaPreTrainedModel):
             attentions=all_self_attns,
         )

+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
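A standalone illustration of the issue the new TODO describes, not the Gemma code itself: a hypothetical `masked_scores` function compiled with cudagraph-backed mode and fed a 2D mask whose length changes at every call, assuming a CUDA device and torch>=2.2.

import torch

@torch.compile(mode="reduce-overhead", fullgraph=True)  # "reduce-overhead" backs the graph with cudagraphs
def masked_scores(scores, attention_mask):
    # Additive masking: positions where the mask is 0 receive the most negative finite value.
    return scores + (1.0 - attention_mask) * torch.finfo(scores.dtype).min

for step in range(8, 12):
    # The 2D mask grows by one column per decode step, so its shape is dynamic across
    # calls; with cudagraphs this forces repeated re-capture instead of a cheap replay.
    mask = torch.ones(1, step, device="cuda")
    scores = torch.randn(1, step, device="cuda")
    masked_scores(scores, mask)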
@@ -955,16 +958,8 @@ class GemmaModel(GemmaPreTrainedModel):
             causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

-        if hasattr(self, "causal_mask"):  # we use the current dtype to avoid any overflows
-            causal_mask = (
-                self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
-            )
-        else:
-            mask = torch.full(
-                (self.config.max_position_embeddings, self.config.max_position_embeddings),
-                fill_value=torch.finfo(dtype).min,
-            )
-            causal_mask = torch.triu(mask, diagonal=1)
+        # We use the current dtype to avoid any overflows
+        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min

         causal_mask = causal_mask.to(dtype=dtype, device=device)
         if attention_mask is not None and attention_mask.dim() == 2:
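A self-contained sketch of what the surviving branch computes, using hypothetical sizes and names (`causal_buffer`, `causal_bias`): the registered buffer holds only 0/1 entries, and the scaling by `torch.finfo(dtype).min` happens at call time in the runtime dtype, which is the overflow concern the comment refers to.

import torch

batch_size, seq_len, dtype = 2, 5, torch.float16

# 0/1 upper-triangular buffer, analogous to the registered `causal_mask` buffer.
causal_buffer = torch.triu(torch.full((seq_len, seq_len), fill_value=1), diagonal=1)

# Expand to (batch, 1, seq, seq) and convert to an additive bias in the runtime dtype:
# 0 on and below the diagonal, the most negative representable value above it.
causal_bias = causal_buffer[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
print(causal_bias[0, 0])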
@@ -1146,29 +1141,32 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
+        if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
             # generation with static cache
-            past_length = past_key_value.get_seq_length()
+            cache_position = kwargs.get("cache_position", None)
+            if cache_position is None:
+                past_length = 0
+            else:
+                past_length = cache_position[-1] + 1
             input_ids = input_ids[:, past_length:]
             position_ids = position_ids[:, past_length:]

         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
         # same goes for position ids. Could also help with continued generation.
-        cache_position = kwargs.get("cache_position", None)
-        if cache_position is None:
-            cache_position = torch.arange(
-                past_length, past_length + position_ids.shape[-1], device=position_ids.device
-            )
+        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}

         model_inputs.update(
             {
-                "position_ids": position_ids,
+                "position_ids": position_ids.contiguous(),
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
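A toy walk-through of the new bookkeeping in the last hunk, using hypothetical tensors rather than a real model: `past_length` is now derived from the `cache_position` kwarg (one past the last cached position) instead of querying the cache layer, the already-cached prefix is sliced off, and the sliced inputs are passed through `.contiguous()`.

import torch

input_ids = torch.tensor([[5, 17, 42, 99]])                  # 4 tokens total
position_ids = torch.arange(input_ids.shape[1])[None, :]

cache_position = torch.tensor([0, 1, 2])                     # positions already in the static cache
past_length = 0 if cache_position is None else int(cache_position[-1]) + 1   # -> 3

# Keep only the suffix that has not been written to the cache yet.
input_ids = input_ids[:, past_length:]                       # tensor([[99]])
position_ids = position_ids[:, past_length:]                 # tensor([[3]])

# Rebuild cache_position for this forward pass.
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1])

# As the diff's own comment explains, `.contiguous()` gives the inputs a static stride
# during decoding, avoiding torchdynamo recompiles guarded on stride.
model_inputs = {
    "input_ids": input_ids.contiguous(),
    "position_ids": position_ids.contiguous(),
    "cache_position": cache_position,
}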