chenpangpang/transformers · commit 8deff3ac (unverified)

[bart-tiny-random] Put a 5MB model on S3 to allow faster exampl… (#3488)

Authored Mar 30, 2020 by Sam Shleifer; committed via GitHub on Mar 30, 2020.
Parent: 1f728657
Showing 3 changed files with 23 additions and 6 deletions:

    examples/summarization/bart/evaluate_cnn.py         +10  -5
    examples/summarization/bart/test_bart_examples.py    +2  -1
    tests/test_modeling_bart.py                         +11  -0
examples/summarization/bart/evaluate_cnn.py

@@ -16,15 +16,17 @@ def chunks(lst, n):
         yield lst[i : i + n]
 
 
-def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
+def generate_summaries(
+    examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
+):
     fout = Path(out_file).open("w")
-    model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device)
+    model = BartForConditionalGeneration.from_pretrained(model_name, output_past=True,).to(device)
     tokenizer = BartTokenizer.from_pretrained("bart-large")
 
     max_length = 140
     min_length = 55
 
-    for batch in tqdm(list(chunks(lns, batch_size))):
+    for batch in tqdm(list(chunks(examples, batch_size))):
         dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
         summaries = model.generate(
             input_ids=dct["input_ids"].to(device),
@@ -51,6 +53,9 @@ def _run_generate():
     parser.add_argument(
         "output_path", type=str, help="where to save summaries",
     )
+    parser.add_argument(
+        "model_name", type=str, default="bart-large-cnn", help="like bart-large-cnn",
+    )
     parser.add_argument(
         "--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
     )
@@ -58,8 +63,8 @@ def _run_generate():
         "--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
     )
     args = parser.parse_args()
-    lns = [" " + x.rstrip() for x in open(args.source_path).readlines()]
-    generate_summaries(lns, args.output_path, batch_size=args.bs, device=args.device)
+    examples = [" " + x.rstrip() for x in open(args.source_path).readlines()]
+    generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)
 
 
 if __name__ == "__main__":
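The evaluation entry point now accepts a model_name, so the 5MB checkpoint can stand in for bart-large-cnn when speed matters. Below is a minimal sketch of calling the updated function directly, assuming the working directory is examples/summarization/bart/ and using hypothetical input/output paths (note the tokenizer is still hard-coded to bart-large; this commit leaves that line untouched):

    # Sketch only: file paths are hypothetical.
    from evaluate_cnn import generate_summaries

    examples = [" " + line.rstrip() for line in open("articles.txt").readlines()]
    generate_summaries(
        examples,
        "summaries.txt",
        "sshleifer/bart-tiny-random",  # the 5MB checkpoint this commit introduces
        batch_size=4,
        device="cpu",
    )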
examples/summarization/bart/test_bart_examples.py

@@ -25,7 +25,8 @@ class TestBartExamples(unittest.TestCase):
         tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
         with tmp.open("w") as f:
             f.write("\n".join(articles))
-        testargs = ["evaluate_cnn.py", str(tmp), output_file_name]
+        testargs = ["evaluate_cnn.py", str(tmp), output_file_name, "sshleifer/bart-tiny-random"]
         with patch.object(sys, "argv", testargs):
             _run_generate()
             self.assertTrue(Path(output_file_name).exists())
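The test now appends the tiny checkpoint as the third positional argument, matching the updated parser. For reference, a sketch of the equivalent standalone invocation, with hypothetical file names:

    import subprocess
    import sys

    # Positional order matches the updated argparse setup:
    # source_path, output_path, model_name.
    subprocess.run(
        [sys.executable, "evaluate_cnn.py", "articles.txt", "summaries.txt", "sshleifer/bart-tiny-random"],
        check=True,
    )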
tests/test_modeling_bart.py

@@ -27,7 +27,9 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
 if is_torch_available():
     import torch
     from transformers import (
+        AutoModel,
         AutoModelForSequenceClassification,
+        AutoTokenizer,
         BartModel,
         BartForConditionalGeneration,
         BartForSequenceClassification,
@@ -183,6 +185,15 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
     def test_inputs_embeds(self):
         pass
 
+    def test_tiny_model(self):
+        model_name = "sshleifer/bart-tiny-random"
+        tiny = AutoModel.from_pretrained(model_name)  # same vocab size
+        tok = AutoTokenizer.from_pretrained(model_name)  # same tokenizer
+        inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
+
+        with torch.no_grad():
+            tiny(**inputs_dict)
+
 
 @require_torch
 class BartHeadTests(unittest.TestCase):
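test_tiny_model exercises the hosted checkpoint end to end. A quick sanity-check sketch, assuming network access to the checkpoint (the ~5MB figure comes from the commit title; the exact parameter count is not stated in this diff):

    import torch
    from transformers import AutoModel, AutoTokenizer

    tiny = AutoModel.from_pretrained("sshleifer/bart-tiny-random")
    tok = AutoTokenizer.from_pretrained("sshleifer/bart-tiny-random")

    # fp32 weights cost 4 bytes per parameter, so the on-disk size is roughly
    # n_params * 4 bytes, which should land on the order of 5MB.
    n_params = sum(p.numel() for p in tiny.parameters())
    print(f"{n_params:,} parameters, ~{n_params * 4 / 1e6:.1f} MB in fp32")

    batch = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
    with torch.no_grad():
        outputs = tiny(**batch)
    print(outputs[0].shape)  # (batch size, sequence length, hidden size)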