Unverified commit 06d0e880 authored by Daniël de Kok, committed by GitHub

Add support for AWQ-quantized Idefics2 (#2233)

Fixes #2036.
parent 0ad7f6f8
@@ -34,6 +34,7 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
 )
+from text_generation_server.utils.weights import DefaultWeightsLoader


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -682,7 +683,7 @@ class Idefics2Connector(nn.Module):
 class Idefics2ForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        config.vision_config.quantize = config.quantize
+        config.vision_config.quantize = None
         config.vision_config.speculator = config.speculator
         config.text_config.quantize = config.quantize
         config.text_config.speculator = config.speculator
@@ -695,16 +696,28 @@ class Idefics2ForConditionalGeneration(nn.Module):
             name="text_model",
         )
         self.dtype = weights.dtype
-        self.vision_model = Idefics2VisionTransformer(
-            prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
-            config=vision_config,
-            weights=weights,
-        )
-        self.connector = Idefics2Connector(
-            prefix=f"{prefix}.model.connector" if prefix else "model.connector",
-            config=config,
-            weights=weights,
-        )
+        # The vision and connector models are not quantized.
+        with weights.use_loader(DefaultWeightsLoader()):
+            self.vision_model = Idefics2VisionTransformer(
+                prefix=(
+                    f"{prefix}.model.vision_model" if prefix else "model.vision_model"
+                ),
+                config=vision_config,
+                weights=weights,
+            )
+            quantize = config.quantize
+            try:
+                config.quantize = None
+                self.connector = Idefics2Connector(
+                    prefix=f"{prefix}.model.connector" if prefix else "model.connector",
+                    config=config,
+                    weights=weights,
+                )
+            finally:
+                config.quantize = quantize
         self.config = config
         self.image_seq_len = config.perceiver_config.resampler_n_latents
         self.image_token_id = config.image_token_id
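Read together, the new constructor applies two scoped overrides so that only the text model is loaded with AWQ weights: `weights.use_loader(DefaultWeightsLoader())` swaps in an unquantized loader while the vision tower and connector are built, and the `try`/`finally` clears `config.quantize` for that stretch and restores it afterwards. Below is a minimal, runnable sketch of that pattern; `ToyWeights`, `ToyConfig`, and `build` are illustrative stand-ins, not the real TGI classes.

from contextlib import contextmanager


class ToyWeights:
    """Toy stand-in for Weights: just tracks which loader is active."""

    def __init__(self, loader: str):
        self.weights_loader = loader

    @contextmanager
    def use_loader(self, loader: str):
        # Same shape as the Weights.use_loader context manager added below.
        old_loader = self.weights_loader
        self.weights_loader = loader
        try:
            yield
        finally:
            self.weights_loader = old_loader


class ToyConfig:
    """Toy stand-in for the model config."""

    def __init__(self, quantize):
        self.quantize = quantize


def build(name: str, config: ToyConfig, weights: ToyWeights) -> str:
    # Stand-in for a submodule constructor that consults both the config
    # flag and the currently active weights loader.
    return f"{name}: quantize={config.quantize}, loader={weights.weights_loader}"


config = ToyConfig(quantize="awq")
weights = ToyWeights(loader="awq")

# The text model sees the quantized loader.
print(build("text_model", config, weights))  # quantize=awq, loader=awq

# The vision tower and connector are built with the default loader and with
# config.quantize temporarily cleared, mirroring the Idefics2 change above.
with weights.use_loader("default"):
    quantize = config.quantize
    try:
        config.quantize = None
        print(build("vision_model", config, weights))  # quantize=None, loader=default
        print(build("connector", config, weights))     # quantize=None, loader=default
    finally:
        config.quantize = quantize

# Both overrides are undone once construction is finished.
print(build("lm_head", config, weights))  # quantize=awq, loader=awq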
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from safetensors import safe_open
@@ -306,6 +307,20 @@ class Weights:
     def get_weights_row(self, prefix: str):
         return self.weights_loader.get_weights_row(self, prefix)

+    @contextmanager
+    def use_loader(self, weights_loader: WeightsLoader):
+        """
+        This context manager can be used to have `Weights` use a different
+        loader for the duration of the context.
+        """
+        old_loader = self.weights_loader
+        self.weights_loader = weights_loader
+        try:
+            yield
+        finally:
+            self.weights_loader = old_loader
+

 def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
     """
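As a standalone illustration of why the `try`/`finally` in `use_loader` matters: the previous loader is restored even if building a submodule inside the `with` block raises. The `state` dictionary and module-level `use_loader` function below are toy stand-ins for the `Weights` attribute and method, not library code.

from contextlib import contextmanager

# Toy stand-in for Weights.weights_loader.
state = {"loader": "awq"}


@contextmanager
def use_loader(loader: str):
    old_loader = state["loader"]
    state["loader"] = loader
    try:
        yield
    finally:
        # Runs on normal exit *and* when the with-block raises.
        state["loader"] = old_loader


with use_loader("default"):
    assert state["loader"] == "default"
assert state["loader"] == "awq"  # restored after a normal exit

try:
    with use_loader("default"):
        raise RuntimeError("submodule construction failed")
except RuntimeError:
    pass
assert state["loader"] == "awq"  # restored despite the exception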