"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "07e3454f034b4889925621e8e3253547d2a04aa7"
Unverified Commit 4d10474f authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Handle nested dict/lists of tensors as inputs in the Trainer (#13338)

parent 3efcfeab
...@@ -1727,22 +1727,30 @@ class Trainer: ...@@ -1727,22 +1727,30 @@ class Trainer:
self.state.log_history.append(output) self.state.log_history.append(output)
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
"""
Prepares one :obj:`data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
"""
if isinstance(data, dict):
return type(data)(**{k: self._prepare_input(v) for k, v in data.items()})
elif isinstance(data, (tuple, list)):
return type(data)(self._prepare_input(v) for v in data)
elif isinstance(data, torch.Tensor):
kwargs = dict(device=self.args.device)
if self.deepspeed and data.dtype != torch.int64:
# NLP models inputs are int64 and those get adjusted to the right dtype of the
# embedding. Other models such as wav2vec2's inputs are already float and thus
# may need special handling to match the dtypes of the model
kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
return data.to(**kwargs)
return data
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
""" """
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
handling potential state. handling potential state.
""" """
for k, v in inputs.items(): inputs = self._prepare_input(inputs)
if isinstance(v, torch.Tensor):
kwargs = dict(device=self.args.device)
if self.deepspeed and inputs[k].dtype != torch.int64:
# NLP models inputs are int64 and those get adjusted to the right dtype of the
# embedding. Other models such as wav2vec2's inputs are already float and thus
# may need special handling to match the dtypes of the model
kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
inputs[k] = v.to(**kwargs)
if self.args.past_index >= 0 and self._past is not None: if self.args.past_index >= 0 and self._past is not None:
inputs["mems"] = self._past inputs["mems"] = self._past
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment