Unverified commit 9080607b, authored by Michael Benayoun, committed by GitHub

Fixed torch.finfo issue with torch.fx (#20040)

parent 6f257bb3
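
Background for this diff: torch.finfo requires a concrete torch.dtype, but during torch.fx symbolic tracing a tensor's dtype attribute is itself a proxy, so the common masking pattern torch.finfo(x.dtype).min breaks tracing. A minimal sketch of the failure mode (the MaskedFill module is hypothetical, not from this commit):

import torch
from torch import fx, nn

class MaskedFill(nn.Module):
    def forward(self, x, mask):
        # torch.finfo needs a real dtype; under symbolic tracing x.dtype is
        # a Proxy, so this call raises unless the tracer patches torch.finfo.
        return x.masked_fill(mask, torch.finfo(x.dtype).min)

# fx.symbolic_trace(MaskedFill())  # fails here without the patching added below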
src/transformers/utils/fx.py
@@ -230,6 +230,15 @@ def torch_arange(*args, **kwargs):
     return torch.empty((end - start) // step, dtype=dtype, device="meta")


+def torch_full(*args, **kwargs):
+    args = list(args)
+    if isinstance(args[1], torch.Tensor) and args[1].device == torch.device("meta"):
+        args[1] = 1  # Any value.
+    kwargs_without_device = dict(kwargs)
+    kwargs_without_device.pop("device", None)
+    return torch.full(*args, **kwargs_without_device)
+
+
 def torch_cat(tensors, dim=None, axis=None, *, out=None):
     if dim is None and axis is None:
         dim = 0
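
A quick sanity check of what the override above does, assuming torch_full from this hunk is in scope: a meta-device fill value has no readable data, so it is replaced by a placeholder, and the device kwarg is dropped before delegating to the real torch.full:

import torch

fill = torch.empty((), device="meta")  # fill value produced on the meta device
out = torch_full((2, 3), fill, dtype=torch.long, device="meta")
print(out.shape, out.dtype)  # torch.Size([2, 3]) torch.int64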
@@ -509,6 +518,7 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
     torch.where: torch_where,
     torch.abs: torch_abs,
     torch.arange: torch_arange,
+    torch.full: torch_full,
     torch.cat: torch_cat,
     torch.stack: torch_stack,
     torch.add: torch_add,
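
For orientation, this table maps real torch functions to meta-safe stand-ins; a simplified sketch of the dispatch idea (hand-written, not the tracer's actual code):

import torch

def run_on_meta(target, *args, **kwargs):
    # Fall back to the real op when no manual override is registered.
    return _MANUAL_META_OVERRIDES.get(target, target)(*args, **kwargs)

out = run_on_meta(torch.full, (2, 2), 0.5)  # routed through torch_full above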
@@ -552,12 +562,6 @@ class HFProxy(Proxy):
     def shape(self):
         return self.tracer.create_proxy("call_method", "size", (self,), {})

-    @property
-    def dtype(self):
-        if hasattr(self, "_metadata") and self._metadata is not None:
-            return self._metadata.dtype
-        return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})
-
     @property
     def device(self):
         # Hack so we can track when devices are used. During meta-tensor propagation,
@@ -597,12 +601,15 @@ class HFAttribute(HFProxy):
         self.tracer = root.tracer
         self._node = None

+        if hasattr(self.root, "_metadata"):
+            self.install_metadata(getattr(self.root._metadata, attr))
+
     @property
     def node(self):
         # the node for attributes is added lazily, since most will just be method calls
         # which do not rely on the getitem call
         if self._node is None:
-            self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
+            self._node = self.tracer.create_proxy("call_function", builtins.getattr, (self.root, self.attr), {}).node
         return self._node

     def __call__(self, *args, **kwargs):
@@ -663,7 +670,18 @@ class HFTracer(Tracer):
     # Feature flag for proxying accesses to buffer values
     proxy_buffer_attributes: bool = True
     allow_insert_stateless_mods: bool = True
-    _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
+    _TORCH_METHODS_TO_PATCH = [
+        "arange",
+        "zeros",
+        "ones",
+        "full",
+        "full_like",
+        "eye",
+        "empty",
+        "tensor",
+        "clamp",
+        "finfo",
+    ]

     def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
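
The strings above name attributes of the torch module that HFTracer swaps out while tracing. A stripped-down illustration of the interception idea (hand-written here, not the library's implementation):

import torch

original_finfo = torch.finfo

def wrapped_finfo(*args, **kwargs):
    # HFTracer's wrapper records the call on the graph; this stub only
    # logs it to show where the interception happens.
    print("torch.finfo intercepted")
    return original_finfo(*args, **kwargs)

torch.finfo = wrapped_finfo
print(torch.finfo(torch.float16).min)  # -65504.0, via the wrapper
torch.finfo = original_finfo  # the tracer always restores the original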
@@ -737,6 +755,8 @@
                 "GPT2DoubleHeadsModel",
             ]:
                 inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
+            elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]:
+                inputs_dict["labels"] = torch.zeros(shape, dtype=torch.float32, device=device)
             else:
                 raise NotImplementedError(
                     f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
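The test updates below exercise this end to end. For reference, tracing a model through the public helper looks like this (a minimal usage sketch; the model choice is arbitrary):

from transformers import BertConfig, BertModel
from transformers.utils.fx import symbolic_trace

model = BertModel(BertConfig())
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])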
tests/test_modeling_common.py
@@ -835,17 +835,14 @@ class ModelTesterMixin:
             filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
             input_names = list(filtered_inputs.keys())

-            model_output = model(**filtered_inputs)
-
-            if (
-                isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
-                and not hasattr(model.config, "problem_type")
-                or model.config.problem_type is None
+            if isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values())) and (
+                not hasattr(model.config, "problem_type") or model.config.problem_type is None
             ):
                 model.config.problem_type = "single_label_classification"

             traced_model = symbolic_trace(model, input_names)
             traced_output = traced_model(**filtered_inputs)
+            model_output = model(**filtered_inputs)

         except Exception as e:
             self.fail(f"Couldn't trace module: {e}")
@@ -871,20 +868,6 @@
                     f"traced {i}th output doesn't match model {i}th output for {model_class}",
                 )

-        # Test that the model can be TorchScripted
-        try:
-            scripted = torch.jit.script(traced_model)
-        except Exception as e:
-            self.fail(f"Could not TorchScript the traced model: {e}")
-
-        scripted_output = scripted(**filtered_inputs)
-        scripted_output = flatten_output(scripted_output)
-
-        for i in range(num_outputs):
-            self.assertTrue(
-                torch.allclose(model_output[i], scripted_output[i]),
-                f"scripted {i}th output doesn't match model {i}th output for {model_class}",
-            )

         # Test that the model can be serialized and restored properly
         with tempfile.TemporaryDirectory() as tmp_dir_name:
             pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")