Unverified Commit f71873c5 authored by Eugene Zapolsky, committed by GitHub

[deepspeed] check whether model is NLP one instead of counting on input type (#21800)



* trying to figure out whether model is NLP

* drop my changes and apply easier fix

* trying to handle all int input types

* fix logic

---------
Co-authored-by: Stas Bekman <stas@stason.org>
parent 72e9ca75
@@ -2562,8 +2562,8 @@ class Trainer:
             return type(data)(self._prepare_input(v) for v in data)
         elif isinstance(data, torch.Tensor):
             kwargs = {"device": self.args.device}
-            if self.deepspeed and data.dtype != torch.int64:
-                # NLP models inputs are int64 and those get adjusted to the right dtype of the
+            if self.deepspeed and (torch.is_floating_point(data) or torch.is_complex(data)):
+                # NLP models inputs are int/uint and those get adjusted to the right dtype of the
                 # embedding. Other models such as wav2vec2's inputs are already float and thus
                 # may need special handling to match the dtypes of the model
                 kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()})
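As a quick illustration (not part of the commit, just a sketch using plain PyTorch): the old gate `data.dtype != torch.int64` would also cast int32/int8/bool inputs to the DeepSpeed dtype, while the new gate only casts floating-point or complex tensors and leaves every integer input untouched.

```python
import torch

# Sketch only: compare the old and the new gating condition on a few dtypes.
# Only floating/complex tensors should be cast to the deepspeed dtype; integer
# token ids must keep their dtype so the embedding layer can consume them.
for t in (
    torch.zeros(2, dtype=torch.int64),    # typical NLP token ids
    torch.zeros(2, dtype=torch.int32),    # token ids with a smaller int dtype
    torch.zeros(2, dtype=torch.float32),  # e.g. raw audio for wav2vec2
):
    old_check = t.dtype != torch.int64                              # pre-PR logic
    new_check = torch.is_floating_point(t) or torch.is_complex(t)   # post-PR logic
    print(f"{t.dtype}: old would cast={old_check}, new casts={new_check}")
```

For an int32 input the old condition evaluates to True and the tensor would be cast to fp16/bf16, which an embedding lookup cannot consume; the new condition keeps all integer dtypes as-is and only adjusts genuinely floating/complex inputs such as wav2vec2 audio.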