chenpangpang/transformers · Commit 8bcf9c8d (unverified)
Authored Jun 07, 2024 by Cyril Vallez; committed by GitHub on Jun 07, 2024

Fix jetmoe model (#31279)

* Fix jetmoe model
* Remove skip-tests
Parent: f868cf73
Showing 2 changed files with 9 additions and 21 deletions (+9, -21)

src/transformers/models/jetmoe/modeling_jetmoe.py    +9 -13
tests/models/jetmoe/test_modeling_jetmoe.py          +0 -8
src/transformers/models/jetmoe/modeling_jetmoe.py

@@ -1404,18 +1404,14 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
-        past_length = 0
-        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
-                max_cache_length = (
-                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                    if past_key_values.get_max_length() is not None
-                    else None
-                )
-                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
+        # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+        past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+        max_cache_length = (
+            torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+            if past_key_values.get_max_length() is not None
+            else None
+        )
+        cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where

@@ -1446,7 +1442,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
                 position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
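For reference, a minimal sketch (not part of the commit) of the retained cache-length bookkeeping, under the assumption stated in the new comment that past_key_values is always a Cache object; the helper name and toy tensors below are illustrative only.

# Minimal sketch mirroring the simplified logic above, assuming `past_key_values`
# is always a `Cache` instance. Helper name and inputs are illustrative only.
import torch
from transformers.cache_utils import DynamicCache

def sketch_cache_lengths(past_key_values, cache_position, input_ids):
    # Same three steps the commit keeps, without the old isinstance(..., Cache) branch.
    past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
    max_cache_length = (
        torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
        if past_key_values.get_max_length() is not None
        else None
    )
    cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
    return past_length, max_cache_length, cache_length

# First generation step: an empty DynamicCache, so past_length and cache_length are 0
# and max_cache_length is None (a DynamicCache has no fixed maximum length).
input_ids = torch.tensor([[1, 2, 3]])
cache_position = torch.arange(input_ids.shape[1])
past_length, max_cache_length, cache_length = sketch_cache_lengths(DynamicCache(), cache_position, input_ids)
print(past_length, max_cache_length, cache_length)  # tensor(0) None tensor(0)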
tests/models/jetmoe/test_modeling_jetmoe.py

@@ -472,14 +472,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("JetMoe flash attention does not support right padding")

-    @unittest.skip("TODO: @ArthurZucker - Breaks after #30536 ")
-    def test_beam_sample_generate(self):
-        pass
-
-    @unittest.skip("TODO: @ArthurZucker - Breaks after #30536 ")
-    def test_generate_from_inputs_embeds_decoder_only(self):
-        pass
-

 @require_torch
 class JetMoeIntegrationTest(unittest.TestCase):
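For context, a hedged sketch (not taken from the test file) of the kind of scenario the re-enabled test_generate_from_inputs_embeds_decoder_only covers: generating from inputs_embeds, which now routes through the past_length == 0 branch changed in modeling_jetmoe.py. The checkpoint id, dtype, and prompt are assumptions for illustration only.

# Hedged sketch of generating from `inputs_embeds`; checkpoint id and settings are assumed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "jetmoe/jetmoe-8b"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("JetMoe is a mixture-of-experts model", return_tensors="pt")
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

# On the first step past_length == 0, so `inputs_embeds` are fed directly; subsequent
# steps fall back to `input_ids` sliced with `cache_position`.
generated = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=10, do_sample=False)
print(tokenizer.decode(generated[0], skip_special_tokens=True))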