Unverified Commit fb7d2469, authored by dumpmemory and committed by GitHub

Try to fix training loss being inconsistent after resuming from an old checkpoint (#25872)



* fix inconsistent loss after resume  #25340

* fix typo

* clean code

* reformatted code

* adjust code according to comments

* adjust check_dataloader_randomsampler location

* return sampler only

* handle sampler is None

* Update src/transformers/trainer_pt_utils.py

thanks @amyeroberts
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent c5e66a40
src/transformers/trainer.py

@@ -65,7 +65,7 @@ from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
 from .optimization import Adafactor, get_scheduler
-from .pytorch_utils import ALL_LAYERNORM_LAYERS
+from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_less_than_1_11
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -85,6 +85,7 @@ from .trainer_pt_utils import (
     distributed_broadcast_scalars,
     distributed_concat,
     find_batch_size,
+    get_dataloader_sampler,
     get_model_param_count,
     get_module_class_from_name,
     get_parameter_names,
@@ -219,6 +220,7 @@ if is_accelerate_available():
 if TYPE_CHECKING:
     import optuna

+
 logger = logging.get_logger(__name__)
@@ -1808,8 +1810,17 @@ class Trainer:
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
-                for _ in train_dataloader:
-                    break
+                sampler = get_dataloader_sampler(train_dataloader)
+                is_random_sampler = isinstance(sampler, RandomSampler)
+                if is_torch_less_than_1_11 or not is_random_sampler:
+                    # We just need to begin an iteration to create the randomization of the sampler.
+                    for _ in train_dataloader:
+                        break
+                else:
+                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
+                    # AT THE VERY END!
+                    sampler = sampler if sampler is not None else []
+                    _ = list(sampler)

         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
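For context, a minimal sketch (plain PyTorch, not Trainer code) of why the hunk above exhausts the sampler instead of just starting an iteration when replaying already-trained epochs: with a seeded `RandomSampler`, every random draw of an epoch has to happen for the generator to end up in the same state as in an uninterrupted run. The helper name `make_sampler` and the variable `epochs_trained` below are illustrative only.

```python
import torch
from torch.utils.data import RandomSampler

data = list(range(10))


def make_sampler(seed: int = 42) -> RandomSampler:
    # Illustrative helper: a RandomSampler driven by an explicitly seeded generator,
    # so the two runs below start from the same random state.
    generator = torch.Generator()
    generator.manual_seed(seed)
    return RandomSampler(data, generator=generator)


# Uninterrupted run: record the sample order seen in epoch 2.
sampler = make_sampler()
orders = [list(sampler) for _ in range(3)]
expected_epoch_2 = orders[2]

# Resumed run: replay the two already-trained epochs by exhausting the sampler
# (not merely starting an iteration), so every random draw happens and the shared
# generator reaches the same state as in the uninterrupted run.
sampler = make_sampler()
epochs_trained = 2
for _ in range(epochs_trained):
    _ = list(sampler)
assert list(sampler) == expected_epoch_2
```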
src/transformers/trainer_pt_utils.py

@@ -55,6 +55,13 @@ except ImportError:
 logger = logging.get_logger(__name__)


+def get_dataloader_sampler(dataloader):
+    if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
+        return get_dataloader_sampler(dataloader.batch_sampler)
+    elif hasattr(dataloader, "sampler"):
+        return dataloader.sampler
+
+
 def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
     if isinstance(tensor_or_array, torch.Tensor):
         if hasattr(torch, "atleast_1d"):
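And a small usage sketch of the `get_dataloader_sampler` helper added above, using plain PyTorch objects rather than the accelerate-prepared dataloaders the Trainer actually passes in: recursing into `batch_sampler` first matters because, with a custom batch sampler, the sampler that actually drives iteration lives inside `dataloader.batch_sampler`.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler


def get_dataloader_sampler(dataloader):
    # Same logic as the hunk above: prefer the sampler nested inside batch_sampler,
    # then fall back to the dataloader's own sampler attribute.
    if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
        return get_dataloader_sampler(dataloader.batch_sampler)
    elif hasattr(dataloader, "sampler"):
        return dataloader.sampler


dataset = torch.arange(10)

# Case 1: shuffle=True makes DataLoader build a RandomSampler wrapped in a BatchSampler.
loader = DataLoader(dataset, batch_size=2, shuffle=True)
print(type(get_dataloader_sampler(loader)).__name__)  # RandomSampler

# Case 2: with an explicit batch_sampler, the ordering sampler sits inside
# loader.batch_sampler; the recursion still finds it.
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=2, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler)
print(type(get_dataloader_sampler(loader)).__name__)  # RandomSampler
```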