Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1fd60fec
Unverified
Commit
1fd60fec
authored
Jun 20, 2024
by
Joao Gante
Committed by
GitHub
Jun 20, 2024
Browse files
RWKV: enable generation tests (#31490)
* add rwkv tests * has_attentions set in individual tests
parent
d28e647f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
19 deletions
+54
-19
src/transformers/models/rwkv/modeling_rwkv.py
src/transformers/models/rwkv/modeling_rwkv.py
+3
-18
tests/generation/test_utils.py
tests/generation/test_utils.py
+4
-0
tests/models/rwkv/test_modeling_rwkv.py
tests/models/rwkv/test_modeling_rwkv.py
+47
-1
No files found.
src/transformers/models/rwkv/modeling_rwkv.py
View file @
1fd60fec
...
@@ -625,6 +625,9 @@ class RwkvModel(RwkvPreTrainedModel):
...
@@ -625,6 +625,9 @@ class RwkvModel(RwkvPreTrainedModel):
use_cache
=
use_cache
if
use_cache
is
not
None
else
(
self
.
config
.
use_cache
if
not
self
.
training
else
False
)
use_cache
=
use_cache
if
use_cache
is
not
None
else
(
self
.
config
.
use_cache
if
not
self
.
training
else
False
)
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
if
attention_mask
is
None
:
logger
.
warning_once
(
"`attention_mask` was passed, but it is unused in this model."
)
if
self
.
training
==
self
.
layers_are_rescaled
:
if
self
.
training
==
self
.
layers_are_rescaled
:
self
.
_rescale_layers
()
self
.
_rescale_layers
()
...
@@ -765,24 +768,6 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
...
@@ -765,24 +768,6 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
def
set_output_embeddings
(
self
,
new_embeddings
):
def
set_output_embeddings
(
self
,
new_embeddings
):
self
.
head
=
new_embeddings
self
.
head
=
new_embeddings
def
generate
(
self
,
*
args
,
**
kwargs
):
# Thin wrapper to raise exceptions when trying to generate with methods that manipulate `past_key_values`.
# RWKV is one of the few models that don't have it (it has `state` instead, which has different properties and
# usage).
try
:
gen_output
=
super
().
generate
(
*
args
,
**
kwargs
)
except
AttributeError
as
exc
:
# Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'"
if
"past_key_values"
in
str
(
exc
):
raise
AttributeError
(
"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`. RWKV "
"doesn't have that attribute, try another generation strategy instead. For the available "
"generation strategies, check this doc: https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
else
:
raise
exc
return
gen_output
def
prepare_inputs_for_generation
(
self
,
input_ids
,
state
=
None
,
inputs_embeds
=
None
,
**
kwargs
):
def
prepare_inputs_for_generation
(
self
,
input_ids
,
state
=
None
,
inputs_embeds
=
None
,
**
kwargs
):
# only last token for inputs_ids if the state is passed along.
# only last token for inputs_ids if the state is passed along.
if
state
is
not
None
:
if
state
is
not
None
:
...
...
tests/generation/test_utils.py
View file @
1fd60fec
...
@@ -464,6 +464,8 @@ class GenerationTesterMixin:
...
@@ -464,6 +464,8 @@ class GenerationTesterMixin:
if
not
hasattr
(
config
,
"use_cache"
):
if
not
hasattr
(
config
,
"use_cache"
):
self
.
skipTest
(
"This model doesn't support caching"
)
self
.
skipTest
(
"This model doesn't support caching"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"rwkv"
]):
self
.
skipTest
(
"Won't fix: model with non-standard dictionary output shapes"
)
config
.
use_cache
=
True
config
.
use_cache
=
True
config
.
is_decoder
=
True
config
.
is_decoder
=
True
...
@@ -624,6 +626,8 @@ class GenerationTesterMixin:
...
@@ -624,6 +626,8 @@ class GenerationTesterMixin:
if
not
hasattr
(
config
,
"use_cache"
):
if
not
hasattr
(
config
,
"use_cache"
):
self
.
skipTest
(
"This model doesn't support caching"
)
self
.
skipTest
(
"This model doesn't support caching"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"rwkv"
]):
self
.
skipTest
(
"Won't fix: model with non-standard dictionary output shapes"
)
model
=
model_class
(
config
).
to
(
torch_device
).
eval
()
model
=
model_class
(
config
).
to
(
torch_device
).
eval
()
logits_process_kwargs
,
_
=
self
.
_get_logits_processor_and_warper_kwargs
(
logits_process_kwargs
,
_
=
self
.
_get_logits_processor_and_warper_kwargs
(
...
...
tests/models/rwkv/test_modeling_rwkv.py
View file @
1fd60fec
...
@@ -269,7 +269,7 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
...
@@ -269,7 +269,7 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
pipeline_model_mapping
=
(
pipeline_model_mapping
=
(
{
"feature-extraction"
:
RwkvModel
,
"text-generation"
:
RwkvForCausalLM
}
if
is_torch_available
()
else
{}
{
"feature-extraction"
:
RwkvModel
,
"text-generation"
:
RwkvForCausalLM
}
if
is_torch_available
()
else
{}
)
)
#
all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
all_generative_model_classes
=
(
RwkvForCausalLM
,)
if
is_torch_available
()
else
()
fx_compatible
=
False
fx_compatible
=
False
test_missing_keys
=
False
test_missing_keys
=
False
test_model_parallel
=
False
test_model_parallel
=
False
...
@@ -422,6 +422,52 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
...
@@ -422,6 +422,52 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
model
=
RwkvModel
.
from_pretrained
(
model_name
)
model
=
RwkvModel
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
def
test_beam_sample_generate_dict_output
(
self
):
# This model has a custom attention output shape AND config flags, let's skip those checks
old_has_attentions
=
self
.
has_attentions
self
.
has_attentions
=
False
super
().
test_beam_sample_generate_dict_output
()
self
.
has_attentions
=
old_has_attentions
def
test_beam_search_generate_dict_output
(
self
):
# This model has a custom attention output shape AND config flags, let's skip those checks
old_has_attentions
=
self
.
has_attentions
self
.
has_attentions
=
False
super
().
test_beam_search_generate_dict_output
()
self
.
has_attentions
=
old_has_attentions
def
test_constrained_beam_search_generate_dict_output
(
self
):
# This model has a custom attention output shape AND config flags, let's skip those checks
old_has_attentions
=
self
.
has_attentions
self
.
has_attentions
=
False
super
().
test_constrained_beam_search_generate_dict_output
()
self
.
has_attentions
=
old_has_attentions
def
test_greedy_generate_dict_outputs
(
self
):
# This model has a custom attention output shape AND config flags, let's skip those checks
old_has_attentions
=
self
.
has_attentions
self
.
has_attentions
=
False
super
().
test_greedy_generate_dict_outputs
()
self
.
has_attentions
=
old_has_attentions
def
test_group_beam_search_generate_dict_output
(
self
):
# This model has a custom attention output shape AND config flags, let's skip those checks
old_has_attentions
=
self
.
has_attentions
self
.
has_attentions
=
False
super
().
test_group_beam_search_generate_dict_output
()
self
.
has_attentions
=
old_has_attentions
def
test_sample_generate_dict_output
(
self
):
# This model has a custom attention output shape AND config flags, let's skip those checks
old_has_attentions
=
self
.
has_attentions
self
.
has_attentions
=
False
super
().
test_sample_generate_dict_output
()
self
.
has_attentions
=
old_has_attentions
@
unittest
.
skip
(
"This model doesn't support padding"
)
def
test_left_padding_compatibility
(
self
):
pass
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
not
is_torch_greater_or_equal_than_2_0
,
reason
=
"See https://github.com/huggingface/transformers/pull/24204"
not
is_torch_greater_or_equal_than_2_0
,
reason
=
"See https://github.com/huggingface/transformers/pull/24204"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment