"...composable_kernel_rocm.git" did not exist on "ff24b537cb5412b4720a8923bbc090de6d020a3b"
Unverified Commit cf4c20b9 authored by Kossai Sbai's avatar Kossai Sbai Committed by GitHub
Browse files

Convert `torch_dtype` as `str` to actual torch data type (i.e. "float16" …to...


Convert `torch_dtype` as `str` to actual torch data type (i.e. "float16" …to `torch.float16`) (#28208)

* Convert torch_dtype as str to actual torch data type (i.e. "float16" to torch.float16)

* Check if passed torch_dtype is an attribute in torch

* Update src/transformers/pipelines/__init__.py

Check type via isinstance
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent ef5ab72f
......@@ -892,6 +892,8 @@ def pipeline(
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
" arguments might conflict, use only one.)"
)
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
model_kwargs["torch_dtype"] = torch_dtype
model_name = model if isinstance(model, str) else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment