Unverified Commit dfd81842 authored by casuallyName's avatar casuallyName Committed by GitHub
Browse files

Fix attribute error problem (#20765)



fix: 修复Trainer无法使用use_legacy_prediction_loop参数的问题

解决使用use_legacy_prediction_loop参数在predict阶段使用prediction_loop进行预测时,遇到AttributeError: 'PredictionOutput' object has no attribute 'num_samples'的问题
Co-authored-by: default avatarZhouHang <zhouhang@idataway.com>
parent 11745b4e
...@@ -3514,7 +3514,7 @@ class Trainer: ...@@ -3514,7 +3514,7 @@ class Trainer:
prediction_loss_only: Optional[bool] = None, prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None, ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval", metric_key_prefix: str = "eval",
) -> PredictionOutput: ) -> EvalLoopOutput:
""" """
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
...@@ -3651,7 +3651,7 @@ class Trainer: ...@@ -3651,7 +3651,7 @@ class Trainer:
if not key.startswith(f"{metric_key_prefix}_"): if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
def _gather_and_numpify(self, tensors, name): def _gather_and_numpify(self, tensors, name):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment