Unverified Commit 3e8761ab authored by Matt's avatar Matt Committed by GitHub
Browse files

Enable DefaultDataCollator class (#14141)

parent 84b9579d
......@@ -74,6 +74,11 @@ def default_data_collator(features: List[InputDataClass], return_tensors="pt") -
class DefaultDataCollator(DataCollatorMixin):
return_tensors: str = "pt"
def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, Any]:
if return_tensors is None:
return_tensors = self.return_tensors
return default_data_collator(features, return_tensors)
def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
import torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment