Unverified Commit 59cd9de3 authored by Yih-Dar, committed by GitHub

Byebye torch 1.10 (#28207)



* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent e768616a
@@ -15,7 +15,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        version: ["1.13", "1.12", "1.11", "1.10"]
+        version: ["1.13", "1.12", "1.11"]
     runs-on: ubuntu-22.04
     steps:
       -
@@ -56,21 +56,10 @@ jobs:
       sha: ${{ github.sha }}
     secrets: inherit

-  run_past_ci_pytorch_1-10:
-    name: PyTorch 1.10
-    if: (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
-    needs: [run_past_ci_pytorch_1-11]
-    uses: ./.github/workflows/self-past.yml
-    with:
-      framework: pytorch
-      version: "1.10"
-      sha: ${{ github.sha }}
-    secrets: inherit
-
   run_past_ci_tensorflow_2-11:
     name: TensorFlow 2.11
     if: (cancelled() != true) && ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci'))
-    needs: [run_past_ci_pytorch_1-10]
+    needs: [run_past_ci_pytorch_1-11]
     uses: ./.github/workflows/self-past.yml
     with:
       framework: tensorflow
@@ -250,7 +250,7 @@ The model itself is a regular [Pytorch `nn.Module`](https://pytorch.org/docs/sta
 ### With pip

-This repository is tested on Python 3.8+, Flax 0.4.1+, PyTorch 1.10+, and TensorFlow 2.6+.
+This repository is tested on Python 3.8+, Flax 0.4.1+, PyTorch 1.11+, and TensorFlow 2.6+.

 You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
@@ -225,7 +225,7 @@ El modelo en si es un [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.h
 ### Con pip

-Este repositorio está probado en Python 3.8+, Flax 0.4.1+, PyTorch 1.10+ y TensorFlow 2.6+.
+Este repositorio está probado en Python 3.8+, Flax 0.4.1+, PyTorch 1.11+ y TensorFlow 2.6+.

 Deberías instalar 🤗 Transformers en un [ambiente virtual](https://docs.python.org/3/library/venv.html). Si no estas familiarizado con los entornos virtuales de Python, consulta la [guía de usuario](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
@@ -201,7 +201,7 @@ checkpoint: जाँच बिंदु
 ### पिप का उपयोग करना

-इस रिपॉजिटरी का परीक्षण Python 3.8+, Flax 0.4.1+, PyTorch 1.10+ और TensorFlow 2.6+ के तहत किया गया है।
+इस रिपॉजिटरी का परीक्षण Python 3.8+, Flax 0.4.1+, PyTorch 1.11+ और TensorFlow 2.6+ के तहत किया गया है।

 आप [वर्चुअल एनवायरनमेंट](https://docs.python.org/3/library/venv.html) में 🤗 ट्रांसफॉर्मर इंस्टॉल कर सकते हैं। यदि आप अभी तक पायथन के वर्चुअल एनवायरनमेंट से परिचित नहीं हैं, तो कृपया इसे [उपयोगकर्ता निर्देश](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) पढ़ें।
@@ -259,7 +259,7 @@ Hugging Faceチームによって作られた **[トランスフォーマーを
 ### pipにて

-このリポジトリは、Python 3.8+, Flax 0.4.1+, PyTorch 1.10+, TensorFlow 2.6+ でテストされています。
+このリポジトリは、Python 3.8+, Flax 0.4.1+, PyTorch 1.11+, TensorFlow 2.6+ でテストされています。

 🤗Transformersは[仮想環境](https://docs.python.org/3/library/venv.html)にインストールする必要があります。Pythonの仮想環境に慣れていない場合は、[ユーザーガイド](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)を確認してください。
@@ -176,7 +176,7 @@ limitations under the License.
 ### pip로 설치하기

-이 저장소는 Python 3.8+, Flax 0.4.1+, PyTorch 1.10+, TensorFlow 2.6+에서 테스트 되었습니다.
+이 저장소는 Python 3.8+, Flax 0.4.1+, PyTorch 1.11+, TensorFlow 2.6+에서 테스트 되었습니다.

 [가상 환경](https://docs.python.org/3/library/venv.html)에 🤗 Transformers를 설치하세요. Python 가상 환경에 익숙하지 않다면, [사용자 가이드](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)를 확인하세요.
@@ -258,7 +258,7 @@ O modelo em si é um [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.ht
 ### Com pip

-Este repositório é testado no Python 3.8+, Flax 0.4.1+, PyTorch 1.10+ e TensorFlow 2.6+.
+Este repositório é testado no Python 3.8+, Flax 0.4.1+, PyTorch 1.11+ e TensorFlow 2.6+.

 Você deve instalar o 🤗 Transformers em um [ambiente virtual](https://docs.python.org/3/library/venv.html). Se você não está familiarizado com ambientes virtuais em Python, confira o [guia do usuário](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
@@ -248,7 +248,7 @@ Hugging Face Hub. Мы хотим, чтобы Transformers позволил ра
 ### С помощью pip

-Данный репозиторий протестирован на Python 3.8+, Flax 0.4.1+, PyTorch 1.10+ и TensorFlow 2.6+.
+Данный репозиторий протестирован на Python 3.8+, Flax 0.4.1+, PyTorch 1.11+ и TensorFlow 2.6+.

 Устанавливать 🤗 Transformers следует в [виртуальной среде](https://docs.python.org/3/library/venv.html). Если вы не знакомы с виртуальными средами Python, ознакомьтесь с [руководством пользователя](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
@@ -251,7 +251,7 @@ limitations under the License.
 ### పిప్ తో

-ఈ రిపోజిటరీ పైథాన్ 3.8+, ఫ్లాక్స్ 0.4.1+, PyTorch 1.10+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది.
+ఈ రిపోజిటరీ పైథాన్ 3.8+, ఫ్లాక్స్ 0.4.1+, PyTorch 1.11+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది.

 మీరు [వర్చువల్ వాతావరణం](https://docs.python.org/3/library/venv.html)లో 🤗 ట్రాన్స్‌ఫార్మర్‌లను ఇన్‌స్టాల్ చేయాలి. మీకు పైథాన్ వర్చువల్ పరిసరాల గురించి తెలియకుంటే, [యూజర్ గైడ్](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) చూడండి.
@@ -201,7 +201,7 @@ checkpoint: 检查点
 ### 使用 pip

-这个仓库已在 Python 3.8+、Flax 0.4.1+、PyTorch 1.10+ 和 TensorFlow 2.6+ 下经过测试。
+这个仓库已在 Python 3.8+、Flax 0.4.1+、PyTorch 1.11+ 和 TensorFlow 2.6+ 下经过测试。

 你可以在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Transformers。如果你还不熟悉 Python 的虚拟环境,请阅此[用户说明](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)
@@ -213,7 +213,7 @@ Tokenizer 為所有的預訓練模型提供了預處理,並可以直接轉換
 ### 使用 pip

-這個 Repository 已在 Python 3.8+、Flax 0.4.1+、PyTorch 1.10+ 和 TensorFlow 2.6+ 下經過測試。
+這個 Repository 已在 Python 3.8+、Flax 0.4.1+、PyTorch 1.11+ 和 TensorFlow 2.6+ 下經過測試。

 你可以在[虛擬環境](https://docs.python.org/3/library/venv.html)中安裝 🤗 Transformers。如果你還不熟悉 Python 的虛擬環境,請閱此[使用者指引](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)
@@ -175,7 +175,7 @@ _deps = [
     "timeout-decorator",
     "timm",
     "tokenizers>=0.14,<0.19",
-    "torch>=1.10,!=1.12.0",
+    "torch>=1.11,!=1.12.0",
     "torchaudio",
     "torchvision",
     "pyctcdecode>=0.4.0",
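For context, a minimal sketch (not part of this commit) of how the tightened pin behaves: the `packaging` library evaluates a specifier like the new `torch>=1.11,!=1.12.0` against candidate versions.

```python
# Illustration only: how the new "torch>=1.11,!=1.12.0" specifier evaluates
# against a few torch releases (uses the third-party `packaging` library).
from packaging.specifiers import SpecifierSet
from packaging.version import Version

torch_spec = SpecifierSet(">=1.11,!=1.12.0")

for candidate in ["1.10.2", "1.11.0", "1.12.0", "1.13.1", "2.1.0"]:
    # `in` is True only if the version satisfies every clause of the specifier
    print(candidate, Version(candidate) in torch_spec)
# expected: 1.10.2 False, 1.11.0 True, 1.12.0 False, 1.13.1 True, 2.1.0 True
```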
@@ -273,30 +273,12 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
     import torch
     from torch.onnx import export
-    from transformers.pytorch_utils import is_torch_less_than_1_11

     print(f"Using framework PyTorch: {torch.__version__}")

     with torch.no_grad():
         input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
         ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)

-        # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
-        # so we check the torch version for backwards compatibility
-        if is_torch_less_than_1_11:
-            export(
-                nlp.model,
-                model_args,
-                f=output.as_posix(),
-                input_names=ordered_input_names,
-                output_names=output_names,
-                dynamic_axes=dynamic_axes,
-                do_constant_folding=True,
-                use_external_data_format=use_external_format,
-                enable_onnx_checker=True,
-                opset_version=opset,
-            )
-        else:
-            export(
-                nlp.model,
-                model_args,
+        export(
+            nlp.model,
+            model_args,
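The branch that survives is the plain `torch.onnx.export` call with none of the arguments PyTorch deprecated in 1.11. A self-contained sketch of that call shape, using a hypothetical toy model rather than a transformers pipeline (requires the `onnx` package at runtime):

```python
# Sketch with a hypothetical toy model, not transformers code: the export call
# shape kept by this change, without the deprecated `enable_onnx_checker` and
# `use_external_data_format` arguments.
import torch

model = torch.nn.Linear(4, 2)
dummy_input = torch.randn(1, 4)

torch.onnx.export(
    model,
    (dummy_input,),              # positional inputs passed to model.forward
    f="linear.onnx",             # path where the ONNX graph is written
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    do_constant_folding=True,
    opset_version=13,
)
```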
@@ -80,7 +80,7 @@ deps = {
     "timeout-decorator": "timeout-decorator",
     "timm": "timm",
     "tokenizers": "tokenizers>=0.14,<0.19",
-    "torch": "torch>=1.10,!=1.12.0",
+    "torch": "torch>=1.11,!=1.12.0",
     "torchaudio": "torchaudio",
     "torchvision": "torchvision",
     "pyctcdecode": "pyctcdecode>=0.4.0",
@@ -97,7 +97,6 @@ from .utils.import_utils import (
     is_torchdynamo_compiling,
 )
 from .utils.quantization_config import AwqConfig, BitsAndBytesConfig, GPTQConfig, QuantizationMethod
-from .utils.versions import require_version_core


 XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
@@ -2898,10 +2897,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")

         if low_cpu_mem_usage:
-            if device_map is not None:
-                # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
-                require_version_core("torch>=1.10")
-
             if is_deepspeed_zero3_enabled():
                 raise ValueError(
                     "DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`."
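The deleted guard only ensured that `torch.cuda.mem_get_info`, the API the removed comment points to for the max-memory utilities used with a `device_map`, was available; with torch >= 1.11 as the floor it always is. A small sketch of that call, for reference:

```python
# Sketch (not from this commit): the API the removed
# require_version_core("torch>=1.10") guard was protecting.
import torch

if torch.cuda.is_available():
    # Returns (free_bytes, total_bytes) for the current CUDA device.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(f"free: {free_bytes / 2**30:.1f} GiB of {total_bytes / 2**30:.1f} GiB")
else:
    print("No CUDA device available; nothing to query.")
```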
@@ -43,23 +43,10 @@ if is_vision_available():
 if is_torch_available():
     import torch

-    from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
-else:
-    is_torch_greater_or_equal_than_1_11 = False
-

 logger = logging.get_logger(__name__)

 DEFAULT_FONT_PATH = "ybelkada/fonts"


-def _check_torch_version():
-    if is_torch_available() and not is_torch_greater_or_equal_than_1_11:
-        raise ImportError(
-            f"You are using torch=={torch.__version__}, but torch>=1.11.0 is required to use "
-            "Pix2StructImageProcessor. Please upgrade torch."
-        )
-
-
 # adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2
 def torch_extract_patches(image_tensor, patch_height, patch_width):
     """
@@ -75,7 +62,6 @@ def torch_extract_patches(image_tensor, patch_height, patch_width):
         The width of the patches to extract.
     """
     requires_backends(torch_extract_patches, ["torch"])
-    _check_torch_version()
     image_tensor = image_tensor.unsqueeze(0)
     patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
@@ -262,7 +248,6 @@ class Pix2StructImageProcessor(BaseImageProcessor):
             A sequence of `max_patches` flattened patches.
         """
         requires_backends(self.extract_flattened_patches, "torch")
-        _check_torch_version()

         # convert to torch
         image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
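With 1.11 as the minimum, `torch_extract_patches` can call `torch.nn.functional.unfold` unconditionally. A standalone sketch of the patch extraction it performs, with illustrative shapes (not transformers code):

```python
# Non-overlapping patch extraction via unfold, the op torch_extract_patches
# relies on; the shapes here are illustrative assumptions.
import torch

image = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)
patch_h, patch_w = 16, 16

# unfold slides a (patch_h, patch_w) window with a matching stride and returns
# a tensor of shape (batch, channels * patch_h * patch_w, num_patches).
patches = torch.nn.functional.unfold(image, (patch_h, patch_w), stride=(patch_h, patch_w))
print(patches.shape)  # torch.Size([1, 768, 196]) -> 14 x 14 patches
```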
@@ -33,7 +33,6 @@ from .config import OnnxConfig
 if is_torch_available():
     from ..modeling_utils import PreTrainedModel
-    from ..pytorch_utils import is_torch_less_than_1_11

 if is_tf_available():
     from ..modeling_tf_utils import TFPreTrainedModel
@@ -167,39 +166,6 @@ def export_pytorch(
             config.patch_ops()

-            # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
-            # so we check the torch version for backwards compatibility
-            if is_torch_less_than_1_11:
-                # export can work with named args but the dict containing named args
-                # has to be the last element of the args tuple.
-                try:
-                    onnx_export(
-                        model,
-                        (model_inputs,),
-                        f=output.as_posix(),
-                        input_names=list(config.inputs.keys()),
-                        output_names=onnx_outputs,
-                        dynamic_axes=dict(chain(config.inputs.items(), config.outputs.items())),
-                        do_constant_folding=True,
-                        use_external_data_format=config.use_external_data_format(model.num_parameters()),
-                        enable_onnx_checker=True,
-                        opset_version=opset,
-                    )
-                except RuntimeError as err:
-                    message = str(err)
-                    if (
-                        message
-                        == "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export without"
-                        " setting use_external_data_format parameter."
-                    ):
-                        message = (
-                            "Exporting model exceed maximum protobuf size of 2GB. Please call torch.onnx.export"
-                            " without setting use_external_data_format parameter or try with torch 1.10+."
-                        )
-                        raise RuntimeError(message)
-                    else:
-                        raise err
-            else:
-                onnx_export(
-                    model,
-                    (model_inputs,),
+            onnx_export(
+                model,
+                (model_inputs,),
@@ -32,9 +32,6 @@ is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse(
 is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
 is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
 is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
-is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11")
-is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
-is_torch_1_8_0 = parsed_torch_version_base == version.parse("1.8.0")


 def softmax_backward_data(parent, grad_output, output, dim, self):
@@ -45,9 +42,6 @@ def softmax_backward_data(parent, grad_output, output, dim, self):
     from torch import _softmax_backward_data

-    if is_torch_less_than_1_11:
-        return _softmax_backward_data(grad_output, output, parent.dim, self)
-    else:
-        return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
+    return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
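The flags that remain all follow the same pattern: parse `torch.__version__` once, then compare against a pinned version. A sketch of that pattern; how `parsed_torch_version_base` strips pre-release or local suffixes is an assumption, only the comparisons appear in the diff above.

```python
# Version-flag pattern kept in this file. The base-version stripping below is
# an assumption; the >= comparisons mirror the lines shown in the diff.
import torch
from packaging import version

parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)

is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")

print(torch.__version__, is_torch_greater_or_equal_than_1_12, is_torch_greater_or_equal_than_1_13)
```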
@@ -64,7 +64,7 @@ from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
 from .optimization import Adafactor, get_scheduler
-from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_less_than_1_11
+from .pytorch_utils import ALL_LAYERNORM_LAYERS
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -1794,7 +1794,7 @@ class Trainer:
                 if version.parse(accelerate_version) > version.parse("0.23.0"):
                     sampler_kinds.append(SeedableRandomSampler)
                 is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
-                if is_torch_less_than_1_11 or not is_random_sampler:
+                if not is_random_sampler:
                     # We just need to begin an iteration to create the randomization of the sampler.
                     for _ in train_dataloader:
                         break