"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "be5534c655f6cb3cc54c079de3f19d671e7a172b"
Unverified commit 28931f81, authored by Sylvain Gugger, committed by GitHub
Browse files

Fix #6092 (#6093)

* Fix #6092

* Format
parent 5e97c829
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from ..tokenization_utils import PreTrainedTokenizer from ..tokenization_utils import PreTrainedTokenizer
from ..tokenization_utils_base import BatchEncoding
InputDataClass = NewType("InputDataClass", Any) InputDataClass = NewType("InputDataClass", Any)
...@@ -33,7 +34,7 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten ...@@ -33,7 +34,7 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
# have the same attributes. # have the same attributes.
# So we will look at the first element as a proxy for what attributes exist # So we will look at the first element as a proxy for what attributes exist
# on the whole batch. # on the whole batch.
if not isinstance(features[0], dict): if not isinstance(features[0], (dict, BatchEncoding)):
features = [vars(f) for f in features] features = [vars(f) for f in features]
first = features[0] first = features[0]
...@@ -78,7 +79,7 @@ class DataCollatorForLanguageModeling: ...@@ -78,7 +79,7 @@ class DataCollatorForLanguageModeling:
mlm_probability: float = 0.15 mlm_probability: float = 0.15
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
if isinstance(examples[0], dict): if isinstance(examples[0], (dict, BatchEncoding)):
examples = [e["input_ids"] for e in examples] examples = [e["input_ids"] for e in examples]
batch = self._tensorize_batch(examples) batch = self._tensorize_batch(examples)
if self.mlm: if self.mlm:
...@@ -151,7 +152,7 @@ class DataCollatorForPermutationLanguageModeling: ...@@ -151,7 +152,7 @@ class DataCollatorForPermutationLanguageModeling:
max_span_length: int = 5 # maximum length of a span of masked tokens max_span_length: int = 5 # maximum length of a span of masked tokens
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
if isinstance(examples[0], dict): if isinstance(examples[0], (dict, BatchEncoding)):
examples = [e["input_ids"] for e in examples] examples = [e["input_ids"] for e in examples]
batch = self._tensorize_batch(examples) batch = self._tensorize_batch(examples)
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch) inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment