Unverified commit c95f8470, authored by aws-sangeetha and committed by GitHub

Clip floating point constants to bf16 range to avoid inf conversion (#20605)


Co-authored-by: EC2 Default User <ec2-user@ip-172-31-40-169.us-west-2.compute.internal>
parent f68796bd
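
The failure this patch guards against reproduces in plain PyTorch, no TPU required: bfloat16 keeps float32's 8 exponent bits but has only 7 mantissa bits, so float32's extreme finite values round past bfloat16's largest finite value to infinity. That is the same implicit downcast XLA applies to every float32 tensor when `XLA_USE_BF16=1`. A minimal sketch of the problem (local illustration, not repository code):

```python
import torch

# torch.finfo(torch.float32).min (~ -3.4028e38) lies just outside the
# representable bfloat16 range (min ~ -3.3895e38), so a straight downcast
# rounds it to -inf -- the cast XLA performs implicitly under XLA_USE_BF16=1.
fp32_min = torch.tensor(torch.finfo(torch.float32).min)
print(fp32_min.to(torch.bfloat16))  # tensor(-inf, dtype=torch.bfloat16)

# Building the constant from the bfloat16 range instead keeps it finite,
# which is why get_parameter_dtype must report bfloat16 under these flags.
bf16_min = torch.tensor(torch.finfo(torch.bfloat16).min)
print(bf16_min.to(torch.bfloat16))  # tensor(-3.3895e+38, dtype=torch.bfloat16)
```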
```diff
@@ -32,7 +32,7 @@ from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss
 from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
-from transformers.utils.import_utils import is_sagemaker_mp_enabled
+from transformers.utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled
 
 from .activations import get_activation
 from .configuration_utils import PretrainedConfig
@@ -68,12 +68,16 @@ from .utils import (
     is_offline_mode,
     is_remote_url,
     is_safetensors_available,
+    is_torch_tpu_available,
     logging,
     replace_return_docstrings,
 )
 from .utils.versions import require_version_core
 
+XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
+XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
+
 if is_accelerate_available():
     from accelerate import __version__ as accelerate_version
     from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
@@ -181,6 +185,17 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
     for t in parameter.parameters():
         last_dtype = t.dtype
         if t.is_floating_point():
+            # Adding fix for https://github.com/pytorch/xla/issues/4152
+            # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
+            # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
+            if is_torch_tpu_available():
+                if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES:
+                    return torch.bfloat16
+                if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES:
+                    if t.dtype == torch.float:
+                        return torch.bfloat16
+                    if t.dtype == torch.double:
+                        return torch.float32
             return t.dtype
     if last_dtype is not None:
 ...
```
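
For context on why `get_parameter_dtype` matters here: model code commonly builds additive attention masks by filling masked positions with `torch.finfo(self.dtype).min`, and `self.dtype` is derived from this helper. A hedged sketch of that pattern, where `build_additive_mask` is a hypothetical stand-in for the callers in transformers, not a function from the repository:

```python
import torch

def build_additive_mask(padding_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: turn a boolean padding mask into the additive
    mask summed into attention scores.

    If `dtype` were reported as float32 while XLA silently runs the model in
    bfloat16, torch.finfo(torch.float32).min would later round to -inf;
    reporting torch.bfloat16 (as the patch does) keeps the fill value finite.
    """
    fill = torch.finfo(dtype).min
    return torch.zeros_like(padding_mask, dtype=dtype).masked_fill(padding_mask, fill)

mask = build_additive_mask(torch.tensor([False, True]), torch.bfloat16)
print(mask)  # tensor([ 0.0000e+00, -3.3895e+38], dtype=torch.bfloat16)
```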