Commit 3e20c2e8, authored by Louis MARTIN and committed by Julien Chaumond
Browse files

Update demo_camembert.py with new classes

parent f12e4d8d
...@@ -5,7 +5,7 @@ import urllib.request ...@@ -5,7 +5,7 @@ import urllib.request
import torch import torch
from transformers.tokenization_camembert import CamembertTokenizer from transformers.tokenization_camembert import CamembertTokenizer
from transformers.modeling_roberta import RobertaForMaskedLM from transformers.modeling_camembert import CamembertForMaskedLM
def fill_mask(masked_input, model, tokenizer, topk=5): def fill_mask(masked_input, model, tokenizer, topk=5):
...@@ -40,19 +40,8 @@ def fill_mask(masked_input, model, tokenizer, topk=5): ...@@ -40,19 +40,8 @@ def fill_mask(masked_input, model, tokenizer, topk=5):
return topk_filled_outputs return topk_filled_outputs
model_path = Path('camembert.v0.pytorch') tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
if not model_path.exists(): model = CamembertForMaskedLM.from_pretrained('camembert-base')
compressed_path = model_path.with_suffix('.tar.gz')
url = 'http://dl.fbaipublicfiles.com/camembert/camembert.v0.pytorch.tar.gz'
print('Downloading model...')
urllib.request.urlretrieve(url, compressed_path)
print('Extracting model...')
with tarfile.open(compressed_path) as f:
f.extractall(model_path.parent)
assert model_path.exists()
tokenizer_path = model_path / 'sentencepiece.bpe.model'
tokenizer = CamembertTokenizer.from_pretrained(tokenizer_path)
model = RobertaForMaskedLM.from_pretrained(model_path)
model.eval() model.eval()
masked_input = "Le camembert est <mask> :)" masked_input = "Le camembert est <mask> :)"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment