Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
fd3eb3e3
Unverified
Commit
fd3eb3e3
authored
Mar 22, 2023
by
Joao Gante
Committed by
GitHub
Mar 22, 2023
Browse files
Beef up Llama tests (#22314)
* tmp commit * beef up llama tests
parent
12febc20
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
18 deletions
+15
-18
tests/generation/test_utils.py
tests/generation/test_utils.py
+1
-1
tests/models/llama/test_modeling_llama.py
tests/models/llama/test_modeling_llama.py
+14
-17
No files found.
tests/generation/test_utils.py
View file @
fd3eb3e3
...
@@ -1463,10 +1463,10 @@ class GenerationTesterMixin:
...
@@ -1463,10 +1463,10 @@ class GenerationTesterMixin:
attention_names
=
[
"encoder_attentions"
,
"decoder_attentions"
,
"cross_attentions"
]
attention_names
=
[
"encoder_attentions"
,
"decoder_attentions"
,
"cross_attentions"
]
for
model_class
in
self
.
all_generative_model_classes
:
for
model_class
in
self
.
all_generative_model_classes
:
config
,
input_ids
,
attention_mask
,
max_length
=
self
.
_get_input_ids_and_config
()
config
,
input_ids
,
attention_mask
,
max_length
=
self
.
_get_input_ids_and_config
()
model
=
model_class
(
config
).
to
(
torch_device
)
# We want to test only encoder-decoder models
# We want to test only encoder-decoder models
if
not
config
.
is_encoder_decoder
:
if
not
config
.
is_encoder_decoder
:
continue
continue
model
=
model_class
(
config
).
to
(
torch_device
)
head_masking
=
{
head_masking
=
{
"head_mask"
:
torch
.
zeros
(
config
.
encoder_layers
,
config
.
encoder_attention_heads
,
device
=
torch_device
),
"head_mask"
:
torch
.
zeros
(
config
.
encoder_layers
,
config
.
encoder_attention_heads
,
device
=
torch_device
),
...
...
tests/models/llama/test_modeling_llama.py
View file @
fd3eb3e3
...
@@ -20,8 +20,10 @@ import unittest
...
@@ -20,8 +20,10 @@ import unittest
from
transformers
import
LlamaConfig
,
is_torch_available
from
transformers
import
LlamaConfig
,
is_torch_available
from
transformers.testing_utils
import
require_torch
,
torch_device
from
transformers.testing_utils
import
require_torch
,
torch_device
from
...generation.test_utils
import
GenerationTesterMixin
from
...test_configuration_common
import
ConfigTester
from
...test_configuration_common
import
ConfigTester
from
...test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
from
...test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
from
...test_pipeline_mixin
import
PipelineTesterMixin
if
is_torch_available
():
if
is_torch_available
():
...
@@ -254,10 +256,21 @@ class LlamaModelTester:
...
@@ -254,10 +256,21 @@ class LlamaModelTester:
@
require_torch
@
require_torch
class
LlamaModelTest
(
ModelTesterMixin
,
unittest
.
TestCase
):
class
LlamaModelTest
(
ModelTesterMixin
,
GenerationTesterMixin
,
PipelineTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
LlamaModel
,
LlamaForCausalLM
,
LlamaForSequenceClassification
)
if
is_torch_available
()
else
()
all_model_classes
=
(
LlamaModel
,
LlamaForCausalLM
,
LlamaForSequenceClassification
)
if
is_torch_available
()
else
()
all_generative_model_classes
=
(
LlamaForCausalLM
,)
if
is_torch_available
()
else
()
all_generative_model_classes
=
(
LlamaForCausalLM
,)
if
is_torch_available
()
else
()
pipeline_model_mapping
=
(
{
"feature-extraction"
:
LlamaModel
,
"text-classification"
:
LlamaForSequenceClassification
,
"text-generation"
:
LlamaForCausalLM
,
"zero-shot"
:
LlamaForSequenceClassification
,
}
if
is_torch_available
()
else
{}
)
test_headmasking
=
False
test_headmasking
=
False
test_pruning
=
False
def
setUp
(
self
):
def
setUp
(
self
):
self
.
model_tester
=
LlamaModelTester
(
self
)
self
.
model_tester
=
LlamaModelTester
(
self
)
...
@@ -316,22 +329,6 @@ class LlamaModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -316,22 +329,6 @@ class LlamaModelTest(ModelTesterMixin, unittest.TestCase):
result
=
model
(
input_ids
,
attention_mask
=
attention_mask
,
labels
=
sequence_labels
)
result
=
model
(
input_ids
,
attention_mask
=
attention_mask
,
labels
=
sequence_labels
)
self
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
model_tester
.
batch_size
,
self
.
model_tester
.
num_labels
))
self
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
model_tester
.
batch_size
,
self
.
model_tester
.
num_labels
))
@
unittest
.
skip
(
"LLaMA does not support head pruning."
)
def
test_head_pruning
(
self
):
pass
@
unittest
.
skip
(
"LLaMA does not support head pruning."
)
def
test_head_pruning_integration
(
self
):
pass
@
unittest
.
skip
(
"LLaMA does not support head pruning."
)
def
test_head_pruning_save_load_from_config_init
(
self
):
pass
@
unittest
.
skip
(
"LLaMA does not support head pruning."
)
def
test_head_pruning_save_load_from_pretrained
(
self
):
pass
@
unittest
.
skip
(
"LLaMA buffers include complex numbers, which breaks this test"
)
@
unittest
.
skip
(
"LLaMA buffers include complex numbers, which breaks this test"
)
def
test_save_load_fast_init_from_base
(
self
):
def
test_save_load_fast_init_from_base
(
self
):
pass
pass
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment