"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1ad7b0398c02a84fb3dd1f110a99d609e1df6726"
Unverified commit e5dcceb8 authored by Matt, committed by GitHub

Fixes to TF collators (#21143)

* Add num_workers for prepare_tf_dataset

* Bugfix in the default collator and change default tensor type

* Remove the "num_workers" arg and move it to a new PR
parent 2411f0e4
@@ -159,7 +159,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
         label_col_name = None
     if label_col_name is not None:
         if isinstance(first[label_col_name], tf.Tensor):
-            dtype = tf.int64 if first[label_col_name].dtype.is_integer() else tf.float32
+            dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32
         elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
             dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
         elif isinstance(first[label_col_name], (tuple, list)):
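For reference, tf.DType.is_integer is a property, not a method, which is why the parentheses in the removed line were a bug: the property already returns a bool, and calling that bool raises a TypeError. A minimal standalone sketch of the dtype check the collator performs (illustrative, not taken from the PR):

import numpy as np
import tensorflow as tf

label = tf.constant([0, 1, 1])
# tf.DType.is_integer is a property; writing `label.dtype.is_integer()` would
# raise "TypeError: 'bool' object is not callable".
dtype = tf.int64 if label.dtype.is_integer else tf.float32
print(dtype)  # <dtype: 'int64'>

float_label = np.array([0.5, 1.5])
# The NumPy branch of the collator uses np.issubdtype instead.
dtype = tf.int64 if np.issubdtype(float_label.dtype, np.integer) else tf.float32
print(dtype)  # <dtype: 'float32'>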
@@ -1345,9 +1345,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         if collate_fn is None:
             if tokenizer is None:
-                collate_fn = DefaultDataCollator(return_tensors="tf")
+                collate_fn = DefaultDataCollator(return_tensors="np")
             else:
-                collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
+                collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np")
         if collate_fn_args is None:
             collate_fn_args = dict()
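The second change switches the default collators used by prepare_tf_dataset from return_tensors="tf" to return_tensors="np", so batches are assembled as NumPy arrays and converted to tensors by the tf.data pipeline. A minimal sketch of the new default behaviour, using made-up toy features (the keys and values below are illustrative, not from the PR):

import numpy as np
from transformers import DefaultDataCollator

collator = DefaultDataCollator(return_tensors="np")
features = [
    {"input_ids": [101, 2009, 102], "labels": 0},
    {"input_ids": [101, 2129, 102], "labels": 1},
]
batch = collator(features)
# Both entries come back as NumPy arrays: input_ids with shape (2, 3)
# and labels with shape (2,).
print({k: (type(v).__name__, v.shape) for k, v in batch.items()})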