Unverified Commit f785c516 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Update code example (#11631)

* Update code example

* Code review
parent 7e406f4a
...@@ -1069,6 +1069,7 @@ class LukeForEntityClassification(LukePreTrainedModel): ...@@ -1069,6 +1069,7 @@ class LukeForEntityClassification(LukePreTrainedModel):
>>> logits = outputs.logits >>> logits = outputs.logits
>>> predicted_class_idx = logits.argmax(-1).item() >>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx]) >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: person
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1181,6 +1182,7 @@ class LukeForEntityPairClassification(LukePreTrainedModel): ...@@ -1181,6 +1182,7 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
>>> logits = outputs.logits >>> logits = outputs.logits
>>> predicted_class_idx = logits.argmax(-1).item() >>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx]) >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: per:cities_of_residence
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1309,8 +1311,12 @@ class LukeForEntitySpanClassification(LukePreTrainedModel): ...@@ -1309,8 +1311,12 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
>>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt") >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
>>> logits = outputs.logits >>> logits = outputs.logits
>>> predicted_class_idx = logits.argmax(-1).item() >>> predicted_class_indices = logits.argmax(-1).squeeze().tolist()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx]) >>> for span, predicted_class_idx in zip(entity_spans, predicted_class_indices):
... if predicted_class_idx != 0:
... print(text[span[0]:span[1]], model.config.id2label[predicted_class_idx])
Beyoncé PER
Los Angeles LOC
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment