chenpangpang / transformers · Commits

Commit fd3eb3e3 (unverified)
Authored Mar 22, 2023 by Joao Gante; committed via GitHub on Mar 22, 2023
Parent: 12febc20

Beef up Llama tests (#22314)

* tmp commit
* beef up llama tests
Changes: 2 changed files, with 15 additions and 18 deletions (+15 -18)

tests/generation/test_utils.py             +1  -1
tests/models/llama/test_modeling_llama.py  +14 -17
tests/generation/test_utils.py

@@ -1463,10 +1463,10 @@ class GenerationTesterMixin:
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            model = model_class(config).to(torch_device)
             # We want to test only encoder-decoder models
             if not config.is_encoder_decoder:
                 continue
+            model = model_class(config).to(torch_device)
             head_masking = {
                 "head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads, device=torch_device),
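Note: the only change in this hunk moves model construction below the is_encoder_decoder guard, so decoder-only classes (such as the newly registered Llama ones) are skipped before a model is ever instantiated on device. A minimal sketch of that pattern follows; DummyConfig and build_model are hypothetical stand-ins, not transformers APIs.

# Minimal sketch: filter on cheap config attributes before paying for
# model construction. DummyConfig and build_model are hypothetical
# stand-ins, not transformers APIs.
from dataclasses import dataclass

@dataclass
class DummyConfig:
    is_encoder_decoder: bool

def build_model(config):
    # Stands in for model_class(config).to(torch_device): the expensive step
    # the reordered guard now avoids for decoder-only configs.
    print("constructing model")
    return object()

for config in (DummyConfig(False), DummyConfig(True)):
    if not config.is_encoder_decoder:  # guard first...
        continue
    model = build_model(config)        # ...construct only what will be tested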
tests/models/llama/test_modeling_llama.py

@@ -20,8 +20,10 @@ import unittest

 from transformers import LlamaConfig, is_torch_available
 from transformers.testing_utils import require_torch, torch_device

+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin


 if is_torch_available():
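Note: both new imports follow the file's existing pattern, where torch-dependent symbols are pulled in only under is_torch_available(), so the module can still be imported in a torch-free environment. A minimal sketch of that guard, assuming transformers is installed; the torch-only imports below mirror the ones this test file uses.

from transformers import is_torch_available

if is_torch_available():
    # Torch-dependent imports live behind the guard so collecting this test
    # module does not fail when torch is absent.
    import torch
    from transformers import LlamaForCausalLM, LlamaModel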
@@ -254,10 +256,21 @@ class LlamaModelTester:

 @require_torch
-class LlamaModelTest(ModelTesterMixin, unittest.TestCase):
+class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification) if is_torch_available() else ()
+    all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": LlamaModel,
+            "text-classification": LlamaForSequenceClassification,
+            "text-generation": LlamaForCausalLM,
+            "zero-shot": LlamaForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     test_headmasking = False
     test_pruning = False

     def setUp(self):
         self.model_tester = LlamaModelTester(self)
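Note: adding GenerationTesterMixin and PipelineTesterMixin to the base list is what actually "beefs up" the suite: unittest collects every inherited test_* method, so LlamaModelTest picks up the shared generation and pipeline tests without any new test code, and pipeline_model_mapping presumably tells PipelineTesterMixin which model class to exercise for each pipeline task. A hypothetical, self-contained illustration of the collection mechanism; the mixin names below are stand-ins, not the real transformers mixins.

import unittest

class GenerationChecks:            # stand-in for GenerationTesterMixin
    def test_greedy_generate(self):
        self.assertTrue(True)      # a real check would call model.generate()

class ModelChecks:                 # stand-in for ModelTesterMixin
    def test_forward(self):
        self.assertTrue(True)

class ExampleModelTest(ModelChecks, GenerationChecks, unittest.TestCase):
    pass                           # inherits both test_* methods

if __name__ == "__main__":
    unittest.main()  # discovers and runs test_forward and test_greedy_generate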
@@ -316,22 +329,6 @@ class LlamaModelTest(ModelTesterMixin, unittest.TestCase):
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    @unittest.skip("LLaMA does not support head pruning.")
-    def test_head_pruning(self):
-        pass
-
-    @unittest.skip("LLaMA does not support head pruning.")
-    def test_head_pruning_integration(self):
-        pass
-
-    @unittest.skip("LLaMA does not support head pruning.")
-    def test_head_pruning_save_load_from_config_init(self):
-        pass
-
-    @unittest.skip("LLaMA does not support head pruning.")
-    def test_head_pruning_save_load_from_pretrained(self):
-        pass
-
     @unittest.skip("LLaMA buffers include complex numbers, which breaks this test")
     def test_save_load_fast_init_from_base(self):
         pass
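Note: the four deleted head-pruning skips look redundant once the class sets test_pruning = False, since the shared mixin gates its pruning tests on that flag; only the complex-buffers skip still needs to stay explicit. A hypothetical reduction of that flag-gating pattern follows; PruningTests is illustrative, not the actual ModelTesterMixin code.

# Illustrative only: approximates how a shared mixin can gate tests on a
# class-level flag, making per-model @unittest.skip methods unnecessary.
import unittest

class PruningTests:
    test_pruning = True  # models without pruning support override this

    def test_head_pruning(self):
        if not self.test_pruning:
            self.skipTest("pruning disabled for this model")  # approximation of the mixin's early exit
        self.assertTrue(True)  # real pruning checks would go here

class ExampleLlamaStyleTest(PruningTests, unittest.TestCase):
    test_pruning = False  # one flag instead of four hand-written skip methods

if __name__ == "__main__":
    unittest.main()  # in this sketch, test_head_pruning is reported as skipped

With a transformers dev checkout, something like "python -m pytest tests/models/llama/test_modeling_llama.py" would be the usual way to run the updated suite.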