Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1fd60fec
Unverified
Commit
1fd60fec
authored
Jun 20, 2024
by
Joao Gante
Committed by
GitHub
Jun 20, 2024
Browse files
RWKV: enable generation tests (#31490)
* add rwkv tests * has_attentions set in individual tests
parent
d28e647f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
19 deletions
+54
-19
src/transformers/models/rwkv/modeling_rwkv.py
src/transformers/models/rwkv/modeling_rwkv.py
+3
-18
tests/generation/test_utils.py
tests/generation/test_utils.py
+4
-0
tests/models/rwkv/test_modeling_rwkv.py
tests/models/rwkv/test_modeling_rwkv.py
+47
-1
No files found.
src/transformers/models/rwkv/modeling_rwkv.py
View file @
1fd60fec
...
@@ -625,6 +625,9 @@ class RwkvModel(RwkvPreTrainedModel):
...
@@ -625,6 +625,9 @@ class RwkvModel(RwkvPreTrainedModel):
use_cache
=
use_cache
if
use_cache
is
not
None
else
(
self
.
config
.
use_cache
if
not
self
.
training
else
False
)
use_cache
=
use_cache
if
use_cache
is
not
None
else
(
self
.
config
.
use_cache
if
not
self
.
training
else
False
)
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
if
attention_mask
is
None
:
logger
.
warning_once
(
"`attention_mask` was passed, but it is unused in this model."
)
if
self
.
training
==
self
.
layers_are_rescaled
:
if
self
.
training
==
self
.
layers_are_rescaled
:
self
.
_rescale_layers
()
self
.
_rescale_layers
()
...
@@ -765,24 +768,6 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
...
@@ -765,24 +768,6 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) embedding module with ``new_embeddings``.

    Standard `PreTrainedModel` hook; for RWKV the output projection is stored
    on ``self.head``.
    """
    self.head = new_embeddings
def generate(self, *args, **kwargs):
    """Generate sequences, translating cache-related failures into an actionable error.

    Thin wrapper around ``super().generate``. RWKV is one of the few models
    without ``past_key_values`` (it carries a recurrent ``state`` instead,
    which has different properties and usage), so decoding strategies that
    manipulate the cache fail with an ``AttributeError`` mentioning
    ``past_key_values``; those are re-raised with a pointer to the supported
    strategies. Every other exception propagates unchanged.
    """
    try:
        gen_output = super().generate(*args, **kwargs)
    except AttributeError as exc:
        # Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'"
        if "past_key_values" in str(exc):
            # Chain with `from exc` so the original failure context is kept
            # visible in the traceback instead of being discarded.
            raise AttributeError(
                "You tried to call `generate` with a decoding strategy that manipulates `past_key_values`. RWKV "
                "doesn't have that attribute, try another generation strategy instead. For the available "
                "generation strategies, check this doc: https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
            ) from exc
        else:
            # Unrelated AttributeError: bare `raise` re-raises with the
            # original traceback intact (idiomatic, unlike `raise exc`).
            raise
    return gen_output
def
prepare_inputs_for_generation
(
self
,
input_ids
,
state
=
None
,
inputs_embeds
=
None
,
**
kwargs
):
def
prepare_inputs_for_generation
(
self
,
input_ids
,
state
=
None
,
inputs_embeds
=
None
,
**
kwargs
):
# only last token for inputs_ids if the state is passed along.
# only last token for inputs_ids if the state is passed along.
if
state
is
not
None
:
if
state
is
not
None
:
...
...
tests/generation/test_utils.py
View file @
1fd60fec
...
@@ -464,6 +464,8 @@ class GenerationTesterMixin:
...
@@ -464,6 +464,8 @@ class GenerationTesterMixin:
if
not
hasattr
(
config
,
"use_cache"
):
if
not
hasattr
(
config
,
"use_cache"
):
self
.
skipTest
(
"This model doesn't support caching"
)
self
.
skipTest
(
"This model doesn't support caching"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"rwkv"
]):
self
.
skipTest
(
"Won't fix: model with non-standard dictionary output shapes"
)
config
.
use_cache
=
True
config
.
use_cache
=
True
config
.
is_decoder
=
True
config
.
is_decoder
=
True
...
@@ -624,6 +626,8 @@ class GenerationTesterMixin:
...
@@ -624,6 +626,8 @@ class GenerationTesterMixin:
if
not
hasattr
(
config
,
"use_cache"
):
if
not
hasattr
(
config
,
"use_cache"
):
self
.
skipTest
(
"This model doesn't support caching"
)
self
.
skipTest
(
"This model doesn't support caching"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"rwkv"
]):
self
.
skipTest
(
"Won't fix: model with non-standard dictionary output shapes"
)
model
=
model_class
(
config
).
to
(
torch_device
).
eval
()
model
=
model_class
(
config
).
to
(
torch_device
).
eval
()
logits_process_kwargs
,
_
=
self
.
_get_logits_processor_and_warper_kwargs
(
logits_process_kwargs
,
_
=
self
.
_get_logits_processor_and_warper_kwargs
(
...
...
tests/models/rwkv/test_modeling_rwkv.py
View file @
1fd60fec
...
@@ -269,7 +269,7 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
...
@@ -269,7 +269,7 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
pipeline_model_mapping
=
(
pipeline_model_mapping
=
(
{
"feature-extraction"
:
RwkvModel
,
"text-generation"
:
RwkvForCausalLM
}
if
is_torch_available
()
else
{}
{
"feature-extraction"
:
RwkvModel
,
"text-generation"
:
RwkvForCausalLM
}
if
is_torch_available
()
else
{}
)
)
#
all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
all_generative_model_classes
=
(
RwkvForCausalLM
,)
if
is_torch_available
()
else
()
fx_compatible
=
False
fx_compatible
=
False
test_missing_keys
=
False
test_missing_keys
=
False
test_model_parallel
=
False
test_model_parallel
=
False
...
@@ -422,6 +422,52 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
...
@@ -422,6 +422,52 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
model
=
RwkvModel
.
from_pretrained
(
model_name
)
model
=
RwkvModel
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
def test_beam_sample_generate_dict_output(self):
    """Run the mixin's beam-sample dict-output test with attention checks off.

    This model has a custom attention output shape AND config flags, so the
    mixin's attention-shape assertions don't apply; ``has_attentions`` is
    flipped off only for the duration of the parent test.
    """
    old_has_attentions = self.has_attentions
    self.has_attentions = False
    try:
        super().test_beam_sample_generate_dict_output()
    finally:
        # Restore even when the parent test fails, so a failure here cannot
        # leak `has_attentions = False` into later tests on this instance.
        self.has_attentions = old_has_attentions
def test_beam_search_generate_dict_output(self):
    """Run the mixin's beam-search dict-output test with attention checks off.

    This model has a custom attention output shape AND config flags, so the
    mixin's attention-shape assertions don't apply; ``has_attentions`` is
    flipped off only for the duration of the parent test.
    """
    old_has_attentions = self.has_attentions
    self.has_attentions = False
    try:
        super().test_beam_search_generate_dict_output()
    finally:
        # Restore even when the parent test fails, so a failure here cannot
        # leak `has_attentions = False` into later tests on this instance.
        self.has_attentions = old_has_attentions
def test_constrained_beam_search_generate_dict_output(self):
    """Run the mixin's constrained-beam-search dict-output test with attention checks off.

    This model has a custom attention output shape AND config flags, so the
    mixin's attention-shape assertions don't apply; ``has_attentions`` is
    flipped off only for the duration of the parent test.
    """
    old_has_attentions = self.has_attentions
    self.has_attentions = False
    try:
        super().test_constrained_beam_search_generate_dict_output()
    finally:
        # Restore even when the parent test fails, so a failure here cannot
        # leak `has_attentions = False` into later tests on this instance.
        self.has_attentions = old_has_attentions
def test_greedy_generate_dict_outputs(self):
    """Run the mixin's greedy dict-output test with attention checks off.

    This model has a custom attention output shape AND config flags, so the
    mixin's attention-shape assertions don't apply; ``has_attentions`` is
    flipped off only for the duration of the parent test.
    """
    old_has_attentions = self.has_attentions
    self.has_attentions = False
    try:
        super().test_greedy_generate_dict_outputs()
    finally:
        # Restore even when the parent test fails, so a failure here cannot
        # leak `has_attentions = False` into later tests on this instance.
        self.has_attentions = old_has_attentions
def test_group_beam_search_generate_dict_output(self):
    """Run the mixin's group-beam-search dict-output test with attention checks off.

    This model has a custom attention output shape AND config flags, so the
    mixin's attention-shape assertions don't apply; ``has_attentions`` is
    flipped off only for the duration of the parent test.
    """
    old_has_attentions = self.has_attentions
    self.has_attentions = False
    try:
        super().test_group_beam_search_generate_dict_output()
    finally:
        # Restore even when the parent test fails, so a failure here cannot
        # leak `has_attentions = False` into later tests on this instance.
        self.has_attentions = old_has_attentions
def test_sample_generate_dict_output(self):
    """Run the mixin's sampling dict-output test with attention checks off.

    This model has a custom attention output shape AND config flags, so the
    mixin's attention-shape assertions don't apply; ``has_attentions`` is
    flipped off only for the duration of the parent test.
    """
    old_has_attentions = self.has_attentions
    self.has_attentions = False
    try:
        super().test_sample_generate_dict_output()
    finally:
        # Restore even when the parent test fails, so a failure here cannot
        # leak `has_attentions = False` into later tests on this instance.
        self.has_attentions = old_has_attentions
@unittest.skip("This model doesn't support padding")
def test_left_padding_compatibility(self):
    # Intentionally a no-op: the common left-padding test is disabled via the
    # decorator above because, per the skip reason, this model doesn't
    # support padding.
    pass
@
unittest
.
skipIf
(
@
unittest
.
skipIf
(
not
is_torch_greater_or_equal_than_2_0
,
reason
=
"See https://github.com/huggingface/transformers/pull/24204"
not
is_torch_greater_or_equal_than_2_0
,
reason
=
"See https://github.com/huggingface/transformers/pull/24204"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment