Unverified Commit 188105a2 authored by Juwan Yoo's avatar Juwan Yoo Committed by GitHub
Browse files

deps: lazy import optional dependencies `gguf` and `torchvision` (#4826)

parent b3953258
...@@ -4,7 +4,6 @@ from dataclasses import dataclass ...@@ -4,7 +4,6 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
import torchvision.transforms as T
from PIL import Image, ImageOps from PIL import Image, ImageOps
from transformers import ( from transformers import (
AutoProcessor, AutoProcessor,
...@@ -76,6 +75,16 @@ class ImageTransform(object): ...@@ -76,6 +75,16 @@ class ImageTransform(object):
self.std = std self.std = std
self.normalize = normalize self.normalize = normalize
# only load torchvision.transforms when needed
try:
import torchvision.transforms as T
# FIXME: add version check for gguf
except ImportError as err:
raise ImportError(
"Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
) from err
transform_pipelines = [T.ToTensor()] transform_pipelines = [T.ToTensor()]
if normalize: if normalize:
......
...@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod ...@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
import gguf
import huggingface_hub import huggingface_hub
import numpy as np import numpy as np
import torch import torch
...@@ -1155,6 +1154,17 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -1155,6 +1154,17 @@ class GGUFModelLoader(BaseModelLoader):
See "Standardized tensor names" in See "Standardized tensor names" in
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
""" """
# only load the gguf module when needed
try:
import gguf
# FIXME: add version check for gguf
except ImportError as err:
raise ImportError(
"Please install gguf via `pip install gguf` to use gguf quantizer."
) from err
config = model_config.hf_config config = model_config.hf_config
model_type = config.model_type model_type = config.model_type
# hack: ggufs have a different name than transformers # hack: ggufs have a different name than transformers
......
...@@ -22,7 +22,6 @@ from typing import ( ...@@ -22,7 +22,6 @@ from typing import (
) )
import filelock import filelock
import gguf
import huggingface_hub.constants import huggingface_hub.constants
import numpy as np import numpy as np
import safetensors.torch import safetensors.torch
...@@ -464,6 +463,8 @@ def pt_weights_iterator( ...@@ -464,6 +463,8 @@ def pt_weights_iterator(
def get_gguf_extra_tensor_names( def get_gguf_extra_tensor_names(
gguf_file: str, gguf_to_hf_name_map: Dict[str, str] gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> List[str]: ) -> List[str]:
import gguf
reader = gguf.GGUFReader(gguf_file) reader = gguf.GGUFReader(gguf_file)
expected_gguf_keys = set(gguf_to_hf_name_map.keys()) expected_gguf_keys = set(gguf_to_hf_name_map.keys())
exact_gguf_keys = set([tensor.name for tensor in reader.tensors]) exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
...@@ -479,6 +480,8 @@ def gguf_quant_weights_iterator( ...@@ -479,6 +480,8 @@ def gguf_quant_weights_iterator(
them to torch tensors them to torch tensors
""" """
import gguf
reader = gguf.GGUFReader(gguf_file) reader = gguf.GGUFReader(gguf_file)
for tensor in reader.tensors: for tensor in reader.tensors:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment