chenpangpang / transformers · Commit 6e72fd09
Add demo_camembert.py
Authored Nov 08, 2019 by Louis MARTIN; committed by Julien Chaumond, Nov 16, 2019
Parent: 14b3aa3b
Showing 1 changed file with 59 additions and 0 deletions.

examples/demo_camembert.py (new file, mode 100644)
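The new example script downloads and extracts the original fairseq CamemBERT checkpoint, loads it with transformers' CamembertTokenizer and RobertaForMaskedLM, and prints the top-k fills for a masked French sentence.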
from pathlib import Path
import tarfile
import urllib.request

import torch

from transformers.tokenization_camembert import CamembertTokenizer
from transformers.modeling_roberta import RobertaForMaskedLM


def fill_mask(masked_input, model, tokenizer, topk=5):
    # Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py
    assert masked_input.count('<mask>') == 1
    input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0)  # Batch size 1
    logits = model(input_ids)[0]  # The prediction scores are the first element of the output tuple
    # Position of the single <mask> token in the input sequence
    masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
    logits = logits[0, masked_index, :]
    prob = logits.softmax(dim=0)
    values, indices = prob.topk(k=topk, dim=0)
    topk_predicted_token_bpe = ' '.join([tokenizer.convert_ids_to_tokens(indices[i].item()) for i in range(len(indices))])
    masked_token = tokenizer.mask_token
    topk_filled_outputs = []
    for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(' ')):
        # '\u2581' is the SentencePiece word-boundary marker; map it back to a space
        predicted_token = predicted_token_bpe.replace('\u2581', ' ')
        if " {0}".format(masked_token) in masked_input:
            # The predicted token already carries its leading space, so drop the one before <mask>
            topk_filled_outputs.append((
                masked_input.replace(' {0}'.format(masked_token), predicted_token),
                values[index].item(),
                predicted_token,
            ))
        else:
            topk_filled_outputs.append((
                masked_input.replace(masked_token, predicted_token),
                values[index].item(),
                predicted_token,
            ))
    return topk_filled_outputs
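A design note on fill_mask: the softmax runs over the vocabulary logits at the mask position only, and each candidate is a single SentencePiece token, so multi-token completions are out of scope for this demo; the '\u2581' (▁) replacement restores normal spacing when a predicted piece starts a new word.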
model_path = Path('camembert.v0.pytorch')
if not model_path.exists():
    compressed_path = model_path.with_suffix('.tar.gz')
    url = 'http://dl.fbaipublicfiles.com/camembert/camembert.v0.pytorch.tar.gz'
    print('Downloading model...')
    urllib.request.urlretrieve(url, compressed_path)
    print('Extracting model...')
    with tarfile.open(compressed_path) as f:
        f.extractall(model_path.parent)
assert model_path.exists()

tokenizer_path = model_path / 'sentencepiece.bpe.model'
tokenizer = CamembertTokenizer.from_pretrained(tokenizer_path)
model = RobertaForMaskedLM.from_pretrained(model_path)
model.eval()

masked_input = "Le camembert est <mask> :)"
print(fill_mask(masked_input, model, tokenizer, topk=3))
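For reference only, a minimal modern equivalent: assuming a later transformers release where CamemBERT ships natively (v2.2+) and the camembert-base checkpoint published afterwards on the Hugging Face model hub, the manual download above becomes unnecessary; fill_mask is the function defined in this script.

from transformers import CamembertTokenizer, CamembertForMaskedLM

# Sketch under the assumptions above: 'camembert-base' is the hub checkpoint
# released after this commit, not the tarball this script downloads.
tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
model = CamembertForMaskedLM.from_pretrained('camembert-base')
model.eval()
print(fill_mask("Le camembert est <mask> :)", model, tokenizer, topk=3))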