Unverified Commit cf4c20b9 authored by Kossai Sbai's avatar Kossai Sbai Committed by GitHub
Browse files

Convert `torch_dtype` as `str` to actual torch data type (i.e. "float16" …to...


Convert `torch_dtype` as `str` to actual torch data type (i.e. "float16" …to `torch.float16`) (#28208)

* Convert torch_dtype as str to actual torch data type (i.e. "float16" to torch.float16)

* Check if passed torch_dtype is an attribute in torch

* Update src/transformers/pipelines/__init__.py

Check type via isinstance
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent ef5ab72f
...@@ -892,6 +892,8 @@ def pipeline( ...@@ -892,6 +892,8 @@ def pipeline(
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those' 'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
" arguments might conflict, use only one.)" " arguments might conflict, use only one.)"
) )
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
model_kwargs["torch_dtype"] = torch_dtype model_kwargs["torch_dtype"] = torch_dtype
model_name = model if isinstance(model, str) else None model_name = model if isinstance(model, str) else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment