Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8bcf9c8d
Unverified
Commit
8bcf9c8d
authored
Jun 07, 2024
by
Cyril Vallez
Committed by
GitHub
Jun 07, 2024
Browse files
Fix jetmoe model (#31279)
* Fix jetmoe model * Remove skip-tests
parent
f868cf73
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
21 deletions
+9
-21
src/transformers/models/jetmoe/modeling_jetmoe.py
src/transformers/models/jetmoe/modeling_jetmoe.py
+9
-13
tests/models/jetmoe/test_modeling_jetmoe.py
tests/models/jetmoe/test_modeling_jetmoe.py
+0
-8
No files found.
src/transformers/models/jetmoe/modeling_jetmoe.py
View file @
8bcf9c8d
...
@@ -1404,18 +1404,14 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
...
@@ -1404,18 +1404,14 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
past_length
=
0
past_length
=
0
if
past_key_values
is
not
None
:
if
past_key_values
is
not
None
:
if
isinstance
(
past_key_values
,
Cache
):
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length
=
cache_position
[
0
]
if
cache_position
is
not
None
else
past_key_values
.
get_seq_length
()
past_length
=
cache_position
[
0
]
if
cache_position
is
not
None
else
past_key_values
.
get_seq_length
()
max_cache_length
=
(
max_cache_length
=
(
torch
.
tensor
(
past_key_values
.
get_max_length
(),
device
=
input_ids
.
device
)
torch
.
tensor
(
past_key_values
.
get_max_length
(),
device
=
input_ids
.
device
)
if
past_key_values
.
get_max_length
()
is
not
None
if
past_key_values
.
get_max_length
()
is
not
None
else
None
else
None
)
)
cache_length
=
past_length
if
max_cache_length
is
None
else
torch
.
min
(
max_cache_length
,
past_length
)
cache_length
=
past_length
if
max_cache_length
is
None
else
torch
.
min
(
max_cache_length
,
past_length
)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else
:
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
max_cache_length
=
None
# Keep only the unprocessed tokens:
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
...
@@ -1446,7 +1442,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
...
@@ -1446,7 +1442,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
position_ids
=
position_ids
[:,
-
input_ids
.
shape
[
1
]
:]
position_ids
=
position_ids
[:,
-
input_ids
.
shape
[
1
]
:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if
inputs_embeds
is
not
None
and
past_
key_values
is
None
:
if
inputs_embeds
is
not
None
and
past_
length
==
0
:
model_inputs
=
{
"inputs_embeds"
:
inputs_embeds
}
model_inputs
=
{
"inputs_embeds"
:
inputs_embeds
}
else
:
else
:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
...
...
tests/models/jetmoe/test_modeling_jetmoe.py
View file @
8bcf9c8d
...
@@ -472,14 +472,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
...
@@ -472,14 +472,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def
test_flash_attn_2_inference_equivalence_right_padding
(
self
):
def
test_flash_attn_2_inference_equivalence_right_padding
(
self
):
self
.
skipTest
(
"JetMoe flash attention does not support right padding"
)
self
.
skipTest
(
"JetMoe flash attention does not support right padding"
)
@
unittest
.
skip
(
"TODO: @ArthurZucker - Breaks after #30536 "
)
def
test_beam_sample_generate
(
self
):
pass
@
unittest
.
skip
(
"TODO: @ArthurZucker - Breaks after #30536 "
)
def
test_generate_from_inputs_embeds_decoder_only
(
self
):
pass
@
require_torch
@
require_torch
class
JetMoeIntegrationTest
(
unittest
.
TestCase
):
class
JetMoeIntegrationTest
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment