Unverified Commit f785c516 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Update code example (#11631)

* Update code example

* Code review
parent 7e406f4a
......@@ -1069,6 +1069,7 @@ class LukeForEntityClassification(LukePreTrainedModel):
>>> logits = outputs.logits
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: person
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......@@ -1181,6 +1182,7 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
>>> logits = outputs.logits
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: per:cities_of_residence
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......@@ -1309,8 +1311,12 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
>>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
>>> predicted_class_indices = logits.argmax(-1).squeeze().tolist()
>>> for span, predicted_class_idx in zip(entity_spans, predicted_class_indices):
... if predicted_class_idx != 0:
... print(text[span[0]:span[1]], model.config.id2label[predicted_class_idx])
Beyoncé PER
Los Angeles LOC
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment