Unverified Commit 26a2ec56 authored by Sylvain Gugger, committed by GitHub

Clean up old Accelerate checks (#24279)

* Clean up old Accelerate checks

* Put back imports
parent 860d11ff
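
For context, the checks being removed follow a pattern where optional Accelerate features were imported behind runtime version gates; with the dependency floor raised to `accelerate>=0.20.3`, those gates collapse into plain imports. A minimal, illustrative sketch of the two shapes (not the exact library code):

```python
# Illustrative sketch only: the old pattern gated imports on the installed
# Accelerate version and fell back to None sentinels that callers had to check.
from packaging import version
from accelerate import __version__ as accelerate_version

if version.parse(accelerate_version) > version.parse("0.19.0"):
    from accelerate.utils import check_tied_parameters_on_same_device
else:
    check_tied_parameters_on_same_device = None  # every call site had to handle None

# With accelerate>=0.20.3 guaranteed, the same symbol is simply imported directly,
# and the None checks removed in the hunks below are no longer needed.
from accelerate.utils import check_tied_parameters_on_same_device
```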
@@ -98,7 +98,7 @@ if stale_egg_info.exists():
 # 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
 _deps = [
     "Pillow",
-    "accelerate>=0.20.2",
+    "accelerate>=0.20.3",
     "av==9.2.0",  # Latest version of PyAV (10.0.0) has issues with audio stream.
     "beautifulsoup4",
     "black~=23.1",
......
@@ -3,7 +3,7 @@
 # 2. run `make deps_table_update``
 deps = {
     "Pillow": "Pillow",
-    "accelerate": "accelerate>=0.20.2",
+    "accelerate": "accelerate>=0.20.3",
     "av": "av==9.2.0",
     "beautifulsoup4": "beautifulsoup4",
     "black": "black~=23.1",
......
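
The dependency table bump means an environment only needs a quick floor check rather than per-feature probing. A generic way to verify this locally (an illustrative check, not how `transformers` enforces the pin internally):

```python
# Hypothetical environment sanity check: confirm the installed Accelerate meets
# the new minimum pinned by this commit (accelerate>=0.20.3).
from importlib.metadata import version as installed_version
from packaging import version

MIN_ACCELERATE = "0.20.3"

found = version.parse(installed_version("accelerate"))
if found < version.parse(MIN_ACCELERATE):
    raise RuntimeError(
        f"accelerate {found} is installed but >= {MIN_ACCELERATE} is required; "
        "upgrade with `pip install -U accelerate`."
    )
```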
@@ -82,27 +82,17 @@ XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
 XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
 
 if is_accelerate_available():
-    from accelerate import __version__ as accelerate_version
     from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
     from accelerate.utils import (
+        check_tied_parameters_on_same_device,
         find_tied_parameters,
+        get_balanced_memory,
         load_offloaded_weights,
         offload_weight,
         save_offload_index,
         set_module_tensor_to_device,
     )
 
-    if version.parse(accelerate_version) > version.parse("0.11.0"):
-        from accelerate.utils import get_balanced_memory
-    else:
-        get_balanced_memory = None
-    if version.parse(accelerate_version) > version.parse("0.19.0"):
-        from accelerate.utils import check_tied_parameters_on_same_device
-    else:
-        check_tied_parameters_on_same_device = None
-else:
-    find_tied_parameters = None
-
 if is_safetensors_available():
     from safetensors import safe_open
     from safetensors.torch import load_file as safe_load_file
@@ -2792,8 +2782,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
                     "'sequential'."
                 )
-            elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
-                raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")
 
             kwargs = {"no_split_module_classes": no_split_modules}
             if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
@@ -2803,7 +2791,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     "This model has some weights that should be kept in higher precision, you need to upgrade "
                     "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
                 )
-            if device_map != "sequential" and get_balanced_memory is not None:
+            if device_map != "sequential":
                 max_memory = get_balanced_memory(
                     model,
                     dtype=target_dtype,
@@ -2838,8 +2826,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             model.tie_weights()
             tied_params = find_tied_parameters(model)
             # check if we don't have tied param in different devices
-            if check_tied_parameters_on_same_device is not None:
-                check_tied_parameters_on_same_device(tied_params, device_map)
+            check_tied_parameters_on_same_device(tied_params, device_map)
 
         if from_tf:
             if resolved_archive_file.endswith(".index"):
@@ -3031,7 +3018,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
 
-        if find_tied_parameters is not None:
+        if is_accelerate_available():
             tied_params = find_tied_parameters(model)
         else:
             tied_params = []
......
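
Because `get_balanced_memory` and `check_tied_parameters_on_same_device` are now imported unconditionally, the `device_map` resolution path in `from_pretrained` can call them directly. A rough standalone sketch of how these Accelerate helpers fit together (simplified, assuming at least one visible accelerator for the balanced-memory step; the real loading path does considerably more):

```python
import torch
from torch import nn
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import (
    check_tied_parameters_on_same_device,
    find_tied_parameters,
    get_balanced_memory,
)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# For "balanced"-style maps, start from an even split of available device memory.
max_memory = get_balanced_memory(model, dtype=torch.float32)
device_map = infer_auto_device_map(model, max_memory=max_memory, dtype=torch.float32)

# Tied weights must end up on the same device before dispatching.
tied_params = find_tied_parameters(model)
check_tied_parameters_on_same_device(tied_params, device_map)

model = dispatch_model(model, device_map=device_map)
```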
@@ -32,8 +32,6 @@ from collections.abc import Mapping
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
-from tqdm.auto import tqdm
-
 
 # Integrations must be imported before ML frameworks:
 # isort: off
@@ -206,14 +204,9 @@ if is_peft_available():
     from peft import PeftModel
 
 
-skip_first_batches = None
 if is_accelerate_available():
+    from accelerate import Accelerator, skip_first_batches
     from accelerate import __version__ as accelerate_version
-
-    if version.parse(accelerate_version) >= version.parse("0.16"):
-        from accelerate import skip_first_batches
-
-    from accelerate import Accelerator
     from accelerate.utils import DistributedDataParallelKwargs
 
     if version.parse(accelerate_version) > version.parse("0.20.3"):
@@ -322,6 +315,7 @@ class Trainer:
     """
 
+    # Those are used as methods of the Trainer in examples.
     from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
 
     def __init__(
         self,
@@ -1714,22 +1708,10 @@ class Trainer:
             logger.info(f"  Continuing training from epoch {epochs_trained}")
             logger.info(f"  Continuing training from global step {self.state.global_step}")
             if not args.ignore_data_skip:
-                if skip_first_batches is None:
-                    logger.info(
-                        f"  Will skip the first {epochs_trained} epochs then the first"
-                        f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
-                        " you can install the latest version of Accelerate with `pip install -U accelerate`.You can"
-                        " also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
-                        " training on data already seen by your model."
-                    )
-                else:
-                    logger.info(
-                        f"  Will skip the first {epochs_trained} epochs then the first"
-                        f" {steps_trained_in_current_epoch} batches in the first epoch."
-                    )
-            if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
-                steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
-                steps_trained_progress_bar.set_description("Skipping the first batches")
+                logger.info(
+                    f"  Will skip the first {epochs_trained} epochs then the first"
+                    f" {steps_trained_in_current_epoch} batches in the first epoch."
+                )
 
         # Update the references
         self.callback_handler.model = self.model
@@ -1787,7 +1769,7 @@ class Trainer:
 
                 rng_to_sync = False
                 steps_skipped = 0
-                if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
+                if steps_trained_in_current_epoch > 0:
                     epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
                     steps_skipped = steps_trained_in_current_epoch
                     steps_trained_in_current_epoch = 0
......
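
Similarly, the Trainer now assumes `skip_first_batches` exists instead of degrading to a tqdm-driven manual skip. A minimal sketch of what that helper does when resuming mid-epoch (illustrative values; in the Trainer the batch count comes from the saved trainer state):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import skip_first_batches

dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
dataloader = DataLoader(dataset, batch_size=2)

steps_trained_in_current_epoch = 3  # e.g. recovered from a checkpoint
resumed = skip_first_batches(dataloader, steps_trained_in_current_epoch)

for batch in resumed:
    # Iteration starts at the 4th batch of the epoch, matching the diff above.
    print(batch)
```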