Commit b92d6842 authored by Matt Maybeno, committed by Julien Chaumond

Use roberta model and update doc strings

parent 66085a13
@@ -478,12 +478,16 @@ class RobertaForTokenClassification(BertPreTrainedModel):
         tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
         model = RobertaForTokenClassification.from_pretrained('roberta-base')
-        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
         labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
         outputs = model(input_ids, labels=labels)
         loss, scores = outputs[:2]
     """
+    config_class = RobertaConfig
+    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
+    base_model_prefix = "roberta"
+
     def __init__(self, config):
         super(RobertaForTokenClassification, self).__init__(config)
         self.num_labels = config.num_labels
...
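For context, the updated doc-string example from this hunk can be run end to end roughly as below. This is a minimal sketch, assuming the transformers package exposes RobertaTokenizer and RobertaForTokenClassification (the exact import path may differ for the library version this commit targets); it only mirrors the snippet above and is not part of the commit.

import torch
from transformers import RobertaTokenizer, RobertaForTokenClassification

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForTokenClassification.from_pretrained('roberta-base')

# add_special_tokens=True wraps the sequence in RoBERTa's <s> ... </s> markers,
# matching the updated doc string above.
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # One label id per token, batch size 1

outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]  # scores: (batch_size, sequence_length, num_labels) logits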
@@ -396,7 +396,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel):
         tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
         model = TFRobertaForTokenClassification.from_pretrained('roberta-base')
-        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
+        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
         outputs = model(input_ids)
         scores = outputs[0]
...
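Similarly, a minimal sketch of the TensorFlow doc-string example, under the same assumption about the transformers package and with TensorFlow 2 installed; again, this mirrors the hunk above rather than adding anything from the commit itself.

import tensorflow as tf
from transformers import RobertaTokenizer, TFRobertaForTokenClassification

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = TFRobertaForTokenClassification.from_pretrained('roberta-base')

input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
outputs = model(input_ids)
scores = outputs[0]  # (batch_size, sequence_length, num_labels) logits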