Unverified Commit 1567bef3 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: PT Dynamo without graph breaks in the main greedy/sample loop (#21648)

parent 7a5533b2
...@@ -298,6 +298,9 @@ class GenerationConfig(PushToHubMixin): ...@@ -298,6 +298,9 @@ class GenerationConfig(PushToHubMixin):
self.validate() self.validate()
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, GenerationConfig):
return False
self_dict = self.__dict__.copy() self_dict = self.__dict__.copy()
other_dict = other.__dict__.copy() other_dict = other.__dict__.copy()
# ignore metadata # ignore metadata
......
...@@ -190,10 +190,10 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil ...@@ -190,10 +190,10 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
# Adding fix for https://github.com/pytorch/xla/issues/4152 # Adding fix for https://github.com/pytorch/xla/issues/4152
# Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1 # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
# and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
if is_torch_tpu_available(): # NOTE: `is_torch_tpu_available()` is checked last as it induces a graph break in torch dynamo
if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES: if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
return torch.bfloat16 return torch.bfloat16
if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES: if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
if t.dtype == torch.float: if t.dtype == torch.float:
return torch.bfloat16 return torch.bfloat16
if t.dtype == torch.double: if t.dtype == torch.double:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment