"vscode:/vscode.git/clone" did not exist on "df311a5ccf50be3031474e289b43b1be43111144"
Unverified Commit 33bf4264 authored by Prajjwal Bhargava's avatar Prajjwal Bhargava Committed by GitHub
Browse files

removed redundant arg in prepare_inputs (#6614)

* removed redundant arg in prepare_inputs

* made same change in prediction_loop
parent cabfdfaf
...@@ -705,7 +705,7 @@ class Trainer: ...@@ -705,7 +705,7 @@ class Trainer:
print(output) print(output)
def _prepare_inputs( def _prepare_inputs(
self, inputs: Dict[str, Union[torch.Tensor, Any]], model: nn.Module self, inputs: Dict[str, Union[torch.Tensor, Any]]
) -> Dict[str, Union[torch.Tensor, Any]]: ) -> Dict[str, Union[torch.Tensor, Any]]:
""" """
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
...@@ -746,7 +746,7 @@ class Trainer: ...@@ -746,7 +746,7 @@ class Trainer:
return self._training_step(model, inputs, self.optimizer) return self._training_step(model, inputs, self.optimizer)
model.train() model.train()
inputs = self._prepare_inputs(inputs, model) inputs = self._prepare_inputs(inputs)
if self.args.fp16 and _use_native_amp: if self.args.fp16 and _use_native_amp:
with autocast(): with autocast():
...@@ -1071,7 +1071,7 @@ class Trainer: ...@@ -1071,7 +1071,7 @@ class Trainer:
""" """
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"]) has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
inputs = self._prepare_inputs(inputs, model) inputs = self._prepare_inputs(inputs)
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment