Commit 3e20c2e8 authored by Louis MARTIN, committed by Julien Chaumond
Browse files

Update demo_camembert.py with new classes

parent f12e4d8d
......@@ -5,7 +5,7 @@ import urllib.request
import torch
from transformers.tokenization_camembert import CamembertTokenizer
from transformers.modeling_roberta import RobertaForMaskedLM
from transformers.modeling_camembert import CamembertForMaskedLM
def fill_mask(masked_input, model, tokenizer, topk=5):
......@@ -40,19 +40,8 @@ def fill_mask(masked_input, model, tokenizer, topk=5):
return topk_filled_outputs
model_path = Path('camembert.v0.pytorch')
if not model_path.exists():
compressed_path = model_path.with_suffix('.tar.gz')
url = 'http://dl.fbaipublicfiles.com/camembert/camembert.v0.pytorch.tar.gz'
print('Downloading model...')
urllib.request.urlretrieve(url, compressed_path)
print('Extracting model...')
with tarfile.open(compressed_path) as f:
f.extractall(model_path.parent)
assert model_path.exists()
tokenizer_path = model_path / 'sentencepiece.bpe.model'
tokenizer = CamembertTokenizer.from_pretrained(tokenizer_path)
model = RobertaForMaskedLM.from_pretrained(model_path)
tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
model = CamembertForMaskedLM.from_pretrained('camembert-base')
model.eval()
masked_input = "Le camembert est <mask> :)"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment