Unverified commit 04e1af94 authored by drbh, committed by GitHub

Enable multiple LoRa adapters (#2010)



* feat: first draft load multiple lora

* feat: load weights within layer and refactor lora pass

* fix: refactor and reduce lora math

* feat: baseline impl single request multi lora support

* feat: prefer lorax implementation and port loading logic

* fix: prefer adapter_data and refactors

* feat: prefer lorax's custom punica kernels and add mlp loras

* fix: adjust batch for bgmv

* fix: adjust adapter_segments logic when in batch

* fix: refactor and move changes to v3 proto

* fix: pass model_id for all flash causal lms

* fix: pass model_id for all causal and seq2seq lms

* fix: add model_id to model test

* feat: add lora support to mistral and refactors

* feat: prefer model id in request

* fix: include rust code for adapter id

* feat: bump launcher and add new lora docs

* feat: support base model generation and refactors

* fix: rename doc to retry ci build

* feat: support vlm models

* fix: add adapter_data param and avoid missing layers

* fix: add adapter_data param to phi and neox

* fix: update all models forwards to include adapter_data

* fix: add model_id to IdeficsCausalLM

* Update lora.md

Fixed a typo

* Update lora.md

Fixing spam image

* fix: add lora kernel to dockerfile, support running without kernels and refactors

* fix: avoid dockerfile conflict

* fix: refactors and adjust flash llama lora logic

* fix: skip llama test due to CI issue (temp)

* fix: skip llama test CI (temp) 2

* fix: revert skips and prefer updated ci token for tests

* fix: refactors and helpful comments

* fix: add noop in TensorParallelAdapterRowLinear too

* fix: refactor and move shard_lora_weights logic

* fix: exit early if no adapter_data

---------
Co-authored-by: Derek <datavistics@gmail.com>
parent a2a97b05
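The diff below wires multi-LoRA support through the launcher, the Python server, and the model classes. As a rough usage sketch (the launcher flag and request field names below are assumptions based on the commit messages and the LoRA docs they reference, not shown in this excerpt), the server is started with a list of adapter ids and each request can then select one of them, or omit the field to generate with the base model:

# Hypothetical usage sketch; flag and field names are assumptions, not confirmed by this diff.
#
# Launch with two adapters preloaded alongside the base model, e.g.:
#   text-generation-launcher --model-id <base-model> --lora-adapters adapter_a,adapter_b
import requests

resp = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "What is low-rank adaptation?",
        # Selects one of the preloaded adapters; omit to generate with the base model.
        "parameters": {"max_new_tokens": 32, "adapter_id": "adapter_a"},
    },
)
print(resp.json()["generated_text"])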
@@ -634,6 +634,7 @@ class IdeficsCausalLM(Model):
tokenizer.add_special_tokens({"pad_token": "<unk>"})
super(IdeficsCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
...
@@ -453,6 +453,7 @@ class Mamba(Model):
model = MambaModel(config, weights)
torch.distributed.barrier(group=self.process_group)
super(Mamba, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
...
@@ -2,12 +2,24 @@ import inspect
import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
from text_generation_server.utils.adapter import (
load_and_merge_adapters,
AdapterParameters,
AdapterSource,
)
from loguru import logger
BASE_MODEL_ADAPTER_ID = "__base_model__"
B = TypeVar("B", bound=Batch)
@@ -15,6 +27,7 @@ B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(
self,
model_id: str,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
requires_padding: bool,
@@ -24,7 +37,9 @@ class Model(ABC):
world_size: int = 1,
sliding_window: Optional[int] = None,
speculate: Optional[int] = None,
adapter_id: str = BASE_MODEL_ADAPTER_ID,
):
self.model_id = model_id
self.model = model.eval()
self.tokenizer = tokenizer
@@ -42,6 +57,13 @@ B = TypeVar("B", bound=Batch)
self.world_size = world_size
self.sliding_window = sliding_window if sliding_window != -1 else None
self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
LayerAdapterWeights
)
self.target_to_layer = self.adapter_target_to_layer()
self.loaded_adapters = set()
self.static_adapter_id = adapter_id
if speculate is None:
speculate = get_speculate()
self.speculate = speculate
@@ -119,3 +141,136 @@
raise RuntimeError(
f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
)
@property
def supports_adapter_loading(self) -> bool:
return False
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
return {}
@property
def adapter_layers(self) -> List[str]:
return []
@property
def default_traced_adapter_layers(self) -> List[str]:
return []
def get_num_layers_for_type(self, layer_type: str) -> int:
return 0
def is_row_parallel(self, layer_type: str) -> bool:
return False
@property
def max_speculative_tokens(self) -> int:
return max(
[
weights.max_speculative_tokens
for weights in self.layer_to_adapter_weights.values()
],
default=0,
)
def load_adapter(
self,
adapter_parameters: AdapterParameters,
adapter_source: AdapterSource,
adapter_index: int,
api_token: str,
dynamic: bool = True,
):
"""Loads adapter weights from disk / host memory on the GPU.
adapter_id must be `BASE_MODEL_ADAPTER_ID` if the adapter is statically loaded
into the model. Otherwise, the adapter weights are applied during the forward
pass and stored separately from the base model parameters.
"""
if adapter_index in self.loaded_adapters:
# Adapter already loaded
return
if not self.supports_adapter_loading:
raise ValueError("This model does not support adapter loading.")
if dynamic and not self.dynamic_adapter_loading_enabled:
raise ValueError(
f"This model was initialized with the adapter {self.static_adapter_id} "
f"and therefore does not support dynamic adapter loading. "
f"Please initialize a new model instance from the base model in "
f"order to use the dynamic adapter loading feature."
)
logger.info(
f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
)
weight_names = tuple([v[0] for v in self.target_to_layer.values()])
(
module_map,
adapter_config,
adapter_weight_names,
adapter_tokenizer,
) = load_and_merge_adapters(
self.model_id,
adapter_parameters,
adapter_source,
adapter_index,
weight_names,
api_token,
False,
)
unused_weight_names = adapter_weight_names.copy()
for layer_name in self.adapter_layers:
adapter_weights = adapter_config.load_batched_adapter_weights(
self,
module_map,
layer_name,
unused_weight_names,
dynamic,
)
if adapter_weights is None:
continue
layer_weights = self.layer_to_adapter_weights[layer_name]
layer_weights.add_adapter(adapter_index, adapter_weights)
if len(unused_weight_names) > 0:
logger.warning(
f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
)
if adapter_tokenizer is not None:
self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
self.loaded_adapters.add(adapter_index)
def offload_adapter(
self,
adapter_parameters: AdapterParameters,
adapter_source: AdapterSource,
adapter_index: int,
):
"""Offloads the adapter weights from GPU to CPU or disk."""
if adapter_index not in self.loaded_adapters:
# Adapter already offloaded
return
if not self.supports_adapter_loading:
raise ValueError("This model does not support adapter loading.")
if not self.dynamic_adapter_loading_enabled:
raise ValueError(
f"This model was initialized with the adapter {self.static_adapter_id} "
f"and therefore does not support dynamic adapter loading. "
f"Please initialize a new model instance from the base model in "
f"order to use the dynamic adapter loading feature."
)
for layer_name in self.adapter_layers:
if layer_name in self.layer_to_adapter_weights:
self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
self.loaded_adapters.remove(adapter_index)
@@ -90,6 +90,7 @@ class MPTSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
...
@@ -63,6 +63,7 @@ class OPTSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
...
@@ -60,6 +60,7 @@ class Phi(CausalLM):
model = PhiForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
...
@@ -62,6 +62,7 @@ class RW(CausalLM):
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
...
@@ -62,6 +62,7 @@ class SantaCoder(CausalLM):
)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
...
@@ -575,6 +575,7 @@ class Seq2SeqLM(Model):
tokenizer.bos_token_id = model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
...
@@ -73,6 +73,7 @@ class T5Sharded(Seq2SeqLM):
torch.distributed.barrier(group=self.process_group)
super(Seq2SeqLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
...
@@ -222,7 +222,9 @@ class VlmCausalLM(BaseFlashMistral):
return VlmCausalLMBatch
def forward(
self,
batch: VlmCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
...
@@ -29,7 +29,10 @@ except (ImportError, NotImplementedError):
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id, set_adapter_to_index
from text_generation_server.utils.adapter import (
AdapterParameters,
)
class SignalHandler:
@@ -192,6 +195,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def serve(
model_id: str,
lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
@@ -203,6 +207,7 @@ def serve(
):
async def serve_inner(
model_id: str,
lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
@@ -211,6 +216,7 @@ def serve(
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
adapter_to_index = {}
if sharded:
server_urls = [
unix_socket_template.format(uds_path, rank)
@@ -224,6 +230,7 @@ def serve(
try:
model = get_model(
model_id,
lora_adapter_ids,
revision,
sharded,
quantize,
@@ -232,10 +239,33 @@ def serve(
trust_remote_code,
max_input_tokens,
)
if len(lora_adapter_ids) > 0:
for index, adapter_id in enumerate(lora_adapter_ids):
# TODO: improve non merged adapter loading and long term
# improve adapter loading as a whole
adapter_parameters = AdapterParameters(
adapter_ids=[adapter_id],
weights=None, # will be set to 1
merge_strategy=0,
density=1.0,
majority_sign_method=0,
)
adapter_index = index + 1
adapter_to_index[adapter_id] = adapter_index
model.load_adapter(
adapter_parameters,
None, # adapter_source
adapter_index,
None, # api_token
False, # dynamic
)
except Exception:
logger.exception("Error when initializing model")
raise
set_adapter_to_index(adapter_to_index)
server = aio.server(
interceptors=[
ExceptionInterceptor(),
@@ -266,6 +296,13 @@ def serve(
set_model_id(model_id)
asyncio.run(
serve_inner(
model_id,
lora_adapter_ids,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
)
)
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/adapter.py
# License: Apache License Version 2.0, January 2004
import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Set, Tuple
from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.utils.merges.strategies import merge_adapters
from text_generation_server.utils import hub
from text_generation_server.adapters.lora import LoraConfig
if TYPE_CHECKING:
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
BASE_MODEL_ADAPTER_ID = "__base_model__"
@dataclass
class AdapterParameters:
adapter_ids: Tuple[str]
weights: Tuple[float]
merge_strategy: NotImplemented
density: float
majority_sign_method: NotImplemented
@dataclass
class AdapterSource:
adapter_id: str
model_id: str
revision: str
def load_and_merge_adapters(
model_id: str,
adapter_parameters: AdapterParameters,
adapter_source: str,
adapter_index: int,
weight_names: Tuple[str],
api_token: str,
trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
if len(adapter_parameters.adapter_ids) == 1:
return load_module_map(
model_id,
adapter_parameters.adapter_ids[0],
adapter_source,
weight_names,
api_token,
trust_remote_code,
)
adapter_params = AdapterParametersContainer(
adapter_parameters, adapter_source, adapter_index
)
return _load_and_merge(
model_id, adapter_params, weight_names, api_token, trust_remote_code
)
@dataclass
class AdapterParametersContainer:
adapter_parameters: AdapterParameters
adapter_source: str
adapter_index: int
def __hash__(self) -> int:
return self.adapter_index
@lru_cache(maxsize=32)
def _load_and_merge(
model_id: str,
adapter_params: AdapterParametersContainer,
weight_names: Tuple[str],
api_token: str,
trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
params = adapter_params.adapter_parameters
adapters_to_merge = []
merged_weight_names = set()
tokenizer = None
for adapter_id in params.adapter_ids:
if adapter_id == BASE_MODEL_ADAPTER_ID:
raise ValueError("Base model adapter cannot be merged.")
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
load_module_map(
model_id,
adapter_id,
adapter_params.adapter_source,
weight_names,
api_token,
trust_remote_code,
)
)
adapters_to_merge.append((module_map, adapter_config))
merged_weight_names = merged_weight_names.union(adapter_weight_names)
if tokenizer is None:
tokenizer = adapter_tokenizer
if len(adapters_to_merge) == 0:
raise ValueError("No adapters to merge.")
module_map, adapter_config = merge_adapters(adapters_to_merge, params)
return module_map, adapter_config, merged_weight_names, tokenizer
def check_architectures(
model_id: str,
adapter_id: str,
adapter_config: "AdapterConfig",
trust_remote_code: bool = False,
):
try:
if not adapter_config.base_model_name_or_path:
# Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
return
expected_config = AutoConfig.from_pretrained(
model_id, trust_remote_code=trust_remote_code
)
model_config = AutoConfig.from_pretrained(
adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code
)
except Exception as e:
warnings.warn(
f"Unable to check architecture compatibility for adapter '{adapter_id}' "
f"against model '{model_id}'. Assuming they are compatible. Error: {e}"
)
return
if model_config.architectures == expected_config.architectures:
warnings.warn(
f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
)
else:
# TODO(travis): revisit this when we support classification heads which will not use CausalLM
raise ValueError(
f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
f"Use --model-id '{adapter_config.base_model_name_or_path}' instead."
)
@lru_cache(maxsize=128)
def load_module_map(
model_id: str,
adapter_id: str,
adapter_source: str,
weight_names: Tuple[str],
api_token: str,
trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
revision = "main"
adapter_config = LoraConfig.load(adapter_id, api_token)
if adapter_config.base_model_name_or_path != model_id:
check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
adapter_filenames = hub._cached_adapter_weight_files(
adapter_id, revision=revision, extension=".safetensors"
)
try:
adapter_tokenizer = AutoTokenizer.from_pretrained(
adapter_config.config_path,
token=api_token,
trust_remote_code=trust_remote_code,
)
except Exception:
# Adapter does not have a tokenizer, so fallback to base model tokenizer
adapter_tokenizer = None
# load adapter weights from all shards (should have relatively small memory footprint)
adapter_weights = {}
for filename in adapter_filenames:
adapter_weights.update(load_file(filename))
# map the model weights to the relevant adapter weights (LoRA A and B matrices)
module_map, adapter_weight_names = adapter_config.map_weights_for_model(
adapter_weights, weight_names
)
return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
@@ -18,6 +18,17 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
def _cached_adapter_weight_files(
adapter_id: str, revision: Optional[str], extension: str
) -> List[str]:
"""Guess weight files from the cached revision snapshot directory"""
d = _get_cached_revision_directory(adapter_id, revision)
if not d:
return []
filenames = _adapter_weight_files_from_dir(d, extension)
return filenames
def _cached_weight_files(
model_id: str, revision: Optional[str], extension: str
) -> List[str]:
@@ -60,6 +71,33 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
return filenames
def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_files_from_dir, that's also what is done there
root, _, files = next(os.walk(str(d)))
filenames = [
os.path.join(root, f)
for f in files
if f.endswith(extension)
and "arguments" not in f
and "args" not in f
and "training" not in f
]
return filenames
def _adapter_config_files_from_dir(d: Path) -> List[str]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_files_from_dir, that's also what is done there
root, _, files = next(os.walk(str(d)))
filenames = [
os.path.join(root, f)
for f in files
if f.endswith(".json") and "arguments" not in f and "args" not in f
]
return filenames
def _get_cached_revision_directory(
model_id: str, revision: Optional[str]
) -> Optional[Path]:
...
import copy
from abc import ABC
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union
import torch
class AdapterParameters:
def __init__(
self, adapter_ids, weights, merge_strategy, density, majority_sign_method
):
self.adapter_ids = adapter_ids
self.weights = weights
self.merge_strategy = merge_strategy
self.density = density
self.majority_sign_method = majority_sign_method
from text_generation_server.utils.merges.utils import (
calculate_majority_sign_mask,
disjoint_merge,
prune,
)
if TYPE_CHECKING:
from text_generation_server.adapters.lora import LoraConfig
from text_generation_server.utils.adapter import ModuleMap
def _apply_weights(
tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
) -> torch.Tensor:
if isinstance(tensors, torch.Tensor):
t = tensors
else:
t = torch.stack(tensors, dim=0)
# element-wise weighting of each task tensor
# need to unsqueeze weights to match task tensor dimensions
# for multiplication to apply element-wise
while len(t.shape) > len(w.shape):
w = w.unsqueeze(-1)
return t * w
class MergeStrategy(ABC):
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
raise NotImplementedError()
class LinearMerge(MergeStrategy):
def __init__(self, **kwargs):
pass
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
weighted_task_tensors = _apply_weights(task_tensors, weights)
return weighted_task_tensors.sum(dim=0)
class TiesMerge(MergeStrategy):
def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
self.density = density
self.majority_sign_method = majority_sign_method
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
# sparsify
task_tensors = [
prune(tensor, self.density, method="magnitude") for tensor in task_tensors
]
task_tensors = torch.stack(task_tensors, dim=0)
# elect sign before applying weights
majority_sign_mask = calculate_majority_sign_mask(
task_tensors, method=self.majority_sign_method
)
weighted_task_tensors = _apply_weights(task_tensors, weights)
# disjoint merge
return disjoint_merge(weighted_task_tensors, majority_sign_mask)
class DareLinearMerge(MergeStrategy):
def __init__(self, density: float, **kwargs):
self.density = density
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
# sparsify
task_tensors = [
prune(tensor, self.density, method="random", rescale=True)
for tensor in task_tensors
]
weighted_task_tensors = _apply_weights(task_tensors, weights)
return weighted_task_tensors.sum(dim=0)
class DareTiesMerge(MergeStrategy):
def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
self.density = density
self.majority_sign_method = majority_sign_method
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
# sparsify
task_tensors = [
prune(tensor, self.density, method="random", rescale=True)
for tensor in task_tensors
]
task_tensors = torch.stack(task_tensors, dim=0)
# elect sign before applying weights
majority_sign_mask = calculate_majority_sign_mask(
task_tensors, method=self.majority_sign_method
)
weighted_task_tensors = _apply_weights(task_tensors, weights)
# disjoint merge
mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask)
return mixed_task_tensors
strategy_registry: Dict[str, Type[MergeStrategy]] = {
"linear": LinearMerge,
"ties": TiesMerge,
"dare_linear": DareLinearMerge,
"dare_ties": DareTiesMerge,
}
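For intuition, a toy run of the registered "linear" strategy on made-up tensors (values are illustrative only, assuming the classes above are in scope):

# Toy illustration of LinearMerge on two made-up task tensors.
import torch

strategy = strategy_registry["linear"]()
t1 = torch.tensor([1.0, 2.0])
t2 = torch.tensor([3.0, 4.0])
# Equal weights: the result is the element-wise weighted sum -> tensor([2.0, 3.0])
merged = strategy.merge([t1, t2], weights=torch.tensor([0.5, 0.5]))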
def merge_adapters(
adapters: List[Tuple["ModuleMap", "LoraConfig"]],
merge_params: AdapterParameters,
) -> Tuple["ModuleMap", "LoraConfig"]:
# strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower()
strategy_name = "linear"
weights = merge_params.weights
if not weights:
weights = torch.ones(len(adapters))
else:
weights = torch.tensor(weights)
merge_config = {
"density": merge_params.density,
# "majority_sign_method": MajoritySignMethodEnum.Name(
# merge_params.majority_sign_method
# ).lower(),
"majority_sign_method": "total",
}
merge_strategy = strategy_registry[strategy_name](**merge_config)
module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(
lambda: defaultdict(lambda: defaultdict(list))
)
lora_configs = []
weight_name_to_adapter_idx = defaultdict(list)
# input is list of (module_map, lora_config) tuples
# convert into dict[k][param_name] -> list of tensors
for idx, (module_map, lora_config) in enumerate(adapters):
for weight_name, data in module_map.items():
weight_name_to_adapter_idx[weight_name].append(idx)
for k, (param_data, param_name) in data.items():
module_maps[weight_name][k][param_name].append(param_data)
lora_configs.append(lora_config)
# validate lora configs are compatible
_validate_lora_configs(lora_configs)
# merge tensors for each module such that we have a single ModuleMap:
# dict[k] -> merged tensor
merged_module_map: "ModuleMap" = defaultdict(dict)
for weight_name, data in module_maps.items():
indices = weight_name_to_adapter_idx[weight_name]
param_weights = weights[indices]
for k, param_data in data.items():
for param_name, tensors in param_data.items():
merged_tensor = merge_strategy.merge(tensors, param_weights)
merged_module_map[weight_name][k] = (merged_tensor, param_name)
# merge lora configs
merged_lora_config = _merge_lora_configs(lora_configs)
return merged_module_map, merged_lora_config
def _validate_lora_configs(lora_configs: List["LoraConfig"]):
# check that all configs have the same rank
ranks = set(lora_config.r for lora_config in lora_configs)
if len(ranks) > 1:
raise ValueError(
f"unable to merge adapters, lora configs have different ranks: {ranks}"
)
if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):
raise ValueError(
"unable to merge adapters, lora configs have no target modules"
)
def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig":
merged_lora_config = copy.copy(lora_configs[0])
# merge target modules as a union operation
merged_target_modules = sorted(
set(
module
for lora_config in lora_configs
for module in lora_config.target_modules
)
)
merged_lora_config.target_modules = merged_target_modules
return merged_lora_config
# coding=utf-8
# From: https://github.com/huggingface/peft/pull/1364
# Copyright 2024-present the HuggingFace Inc. team.
# Modifications by Predibase, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal
import torch
def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
"""
Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction
`density`.
Args:
tensor (`torch.Tensor`):The tensor to prune.
density (`float`):The fraction of values to preserve. Should be in [0,1].
"""
mask = torch.zeros_like(tensor).reshape(-1)
k = int(density * tensor.reshape(-1).shape[0])
top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True)
mask[top_k[1]] = 1
return tensor * mask.reshape(tensor.shape)
def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
"""
Prune random values of the task tensors based on the specified fraction `density`.
Args:
tensor (`torch.Tensor`):The tensor to prune.
density (`float`):The fraction of values to preserve. Should be in [0,1].
rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.
"""
mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
pruned_tensor = tensor * mask
if rescale:
pruned_tensor = torch.div(input=pruned_tensor, other=density)
return pruned_tensor
def prune(
tensor: torch.Tensor,
density: float,
method: Literal["magnitude", "random"],
rescale: bool = False,
) -> torch.Tensor:
"""
Prune the values of task tensors based on the `method`.
Args:
tensor (`torch.Tensor`):The tensor to prune.
density (`float`):The fraction of values to preserve. Should be in [0,1].
method (`str`):The method to use to prune. Should be one of ["magnitude", "random"].
rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.
"""
if density >= 1:
return tensor
elif density < 0:
raise ValueError("Density should be >= 0, got {density}")
if method == "magnitude":
return magnitude_based_pruning(tensor, density)
elif method == "random":
return random_pruning(tensor, density, rescale=rescale)
else:
raise ValueError(f"Unknown method {method}")
def calculate_majority_sign_mask(
tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"
):
"""
Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0.
Args:
tensor (`torch.Tensor`):The tensor to get the mask from.
method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"].
"""
sign = tensor.sign()
if method == "total":
sign_magnitude = (sign * tensor.abs()).sum(dim=0)
elif method == "frequency":
sign_magnitude = sign.sum(dim=0)
else:
raise RuntimeError(f'Unimplemented mask method "{method}"')
majority_sign = torch.where(sign_magnitude >= 0, 1, -1)
return sign == majority_sign
def disjoint_merge(task_tensors, majority_sign_mask):
mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0)
num_params_preserved = majority_sign_mask.sum(dim=0)
return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0)
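Putting the helpers together, a TIES-style merge of two toy tensors looks roughly like this (values are made up for illustration):

# Toy walkthrough of the TIES-style pipeline using the helpers above (values made up):
# magnitude-prune each task tensor, elect the majority sign, then disjoint-merge.
import torch

a = torch.tensor([0.9, -0.1, 0.4])
b = torch.tensor([-0.8, 0.2, 0.5])
stacked = torch.stack(
    [prune(a, 0.67, method="magnitude"), prune(b, 0.67, method="magnitude")], dim=0
)
mask = calculate_majority_sign_mask(stacked, method="total")
# Only values agreeing with the elected sign contribute, averaged over the number of
# contributing tensors -> tensor([0.9000, 0.0000, 0.4500])
merged = disjoint_merge(stacked, mask)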
import os
from typing import Union
from loguru import logger
import torch
@@ -43,3 +43,26 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
model.save_pretrained(cache_dir, safe_serialization=True)
model.config.save_pretrained(cache_dir)
tokenizer.save_pretrained(cache_dir)
def download_peft(
model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
):
torch_dtype = torch.float16
try:
_model = AutoPeftModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
except Exception:
_model = AutoPeftModelForSeq2SeqLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
logger.info("Peft model downloaded.")
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/segments.py
# License: Apache License Version 2.0, January 2004
from typing import List, Tuple, Union
import torch
def find_segments(
adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
segments = [0]
segment_indices = []
if isinstance(adapter_indices, torch.Tensor):
# Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first
adapter_indices = adapter_indices.cpu().tolist()
start_index = 0
for i in range(1, len(adapter_indices)):
if adapter_indices[i] != adapter_indices[i - 1]:
segments.append(i)
segment_indices.append(adapter_indices[i - 1])
start_index = i
# Handle the last segment
if start_index < len(adapter_indices):
segments.append(len(adapter_indices))
segment_indices.append(adapter_indices[-1])
return segments, segment_indices
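A quick example of what find_segments produces for a hypothetical batch:

# Hypothetical batch: the first two requests use adapter 0, the next three use adapter 1.
# The boundaries are [0, 2, 5] and the per-segment adapter ids are [0, 1].
import torch

segments, segment_indices = find_segments(torch.tensor([0, 0, 1, 1, 1]))
assert segments == [0, 2, 5]
assert segment_indices == [0, 1]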
class SegmentConcatBuilder:
def __init__(self):
self.adapter_segment_indices = []
self.adapter_segment_tensors = []
def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
# Update adapter segments
if self.adapter_segment_tensors:
# Because we have already processed at least one batch, remove the 0 start index
# from this batch denoting the beginning of the segment, then offset all segment
# positions by the value of the last segment in the previous batch to account for
# the concatenation.
adapter_segments = (
adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]
)
if (
self.adapter_segment_indices
and self.adapter_segment_indices[-1] == segment_indices[0]
):
# If the last segment in the previous batch is the same as the first segment in this batch,
# then we merge them together into a single segment. In effect, this means removing it from
# the segment indices of this batch, and extending the segment span by removing the segment
# end index from the previous batch.
segment_indices = segment_indices[1:]
self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]
self.adapter_segment_indices.extend(segment_indices)
self.adapter_segment_tensors.append(adapter_segments)
def build(self) -> Tuple[torch.Tensor, List[int]]:
return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices
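A short two-batch walkthrough of the builder (adapter ids are made up):

# Two hypothetical batches: A uses adapters [0, 0, 1], B uses adapters [1, 2].
# After concatenation the segments should describe the combined ids [0, 0, 1, 1, 2].
import torch

builder = SegmentConcatBuilder()
segs_a, idx_a = find_segments([0, 0, 1])  # ([0, 2, 3], [0, 1])
builder.concat(torch.tensor(segs_a), idx_a)
segs_b, idx_b = find_segments([1, 2])  # ([0, 1, 2], [1, 2])
builder.concat(torch.tensor(segs_b), idx_b)
segments, indices = builder.build()
# segments == tensor([0, 2, 4, 5]), indices == [0, 1, 2]: the adjacent adapter-1 spans were merged.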
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/sgmv.py
# License: Apache License Version 2.0, January 2004
import os
import warnings
from functools import lru_cache
from typing import List, Tuple
import torch
import torch.nn.functional as F
try:
import punica_kernels as _kernels
HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
except ImportError:
warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
_kernels = None
HAS_SGMV = False
MIN_SGMV_RANK = 8
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128
SGMV_BLOCK_SIZE = 16
BGMV_MAX_RANK = 64
def has_sgmv() -> bool:
return HAS_SGMV
def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
"""Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
if not has_sgmv():
return t
# tensor parallelism will result in effective rank being divided by world_size,
# so we need to scale the min rank to offset that effect
min_rank = MIN_SGMV_RANK * world_size
# if we're at or below the min rank, pad up to the min rank
# otherwise, pad to the nearest multiple of the block size
current_rank = t.size(dim)
target_rank = (
min_rank
if current_rank <= min_rank
else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
)
if current_rank == target_rank:
return t
pad_size = target_rank - current_rank
# see complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
pad = [0, 0] * t.dim()
pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
pad = tuple(pad)
return F.pad(t, pad, mode="constant", value=0.0)
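For example, assuming the punica kernels are available (has_sgmv() is True) and a single shard:

# Assuming has_sgmv() is True and world_size == 1; otherwise pad_rank returns the tensor unchanged.
import torch

lora_b = torch.zeros(4096, 10, dtype=torch.float16)
padded = pad_rank(lora_b, dim=1, world_size=1)
# Rank 10 is above MIN_SGMV_RANK (8), so it is padded up to the next SGMV_BLOCK_SIZE
# multiple: padded.shape == (4096, 16). A rank-6 tensor would instead be padded to 8.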
def use_cutlass_shrink(lora_rank: int) -> bool:
return lora_rank < MIN_RANK_CUSTOM
def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
return t.transpose(0, 1)
return t
# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
def add_lora_sgmv_cutlass(
y: torch.Tensor,
x: torch.Tensor,
wa_ptr: torch.Tensor,
wb_ptr: torch.Tensor,
s_start: torch.Tensor,
s_end: torch.Tensor,
layer_idx: int,
lora_rank: int,
):
"""
Semantics:
y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
Weight matrix shape: `[num_layers, R, H1]`.
wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
Weight matrix shape: `[num_layers, R, H2]`.
s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
layer_idx: Layer index of the weight matrices.
"""
if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
# Custom SGMV shrink only supports rank 16, 32, 64, 128
_add_lora_sgmv_cutlass_legacy(
y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank
)
return
tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
_kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
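For intuition, a naive PyTorch rendering of the SGMV semantics described in the docstring above (illustrative only: it ignores the kernel's pointer-based weight layout and per-layer indexing, and the function name is made up):

# Naive reference for the SGMV semantics above (not the kernel path):
# for each segment i, y[s_start[i]:s_end[i]] += x[s_start[i]:s_end[i]] @ wa_i.T @ wb_i
# where wa_i has shape [R, H1] and wb_i has shape [R, H2].
def add_lora_sgmv_reference(y, x, wa_list, wb_list, s_start, s_end):
    for i, (wa, wb) in enumerate(zip(wa_list, wb_list)):
        lo, hi = int(s_start[i]), int(s_end[i])
        if hi <= lo:
            continue  # empty segment, nothing to add
        y[lo:hi] += x[lo:hi] @ wa.transpose(0, 1) @ wb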
def _add_lora_sgmv_cutlass_legacy(
y: torch.Tensor,
x: torch.Tensor,
wa_ptr: torch.Tensor,
wb_ptr: torch.Tensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
lora_rank: int,
):
tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
_kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
@lru_cache(maxsize=1)
def get_tmp_tensor(device: torch.device) -> torch.Tensor:
return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
@lru_cache(maxsize=32)
def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor:
return torch.empty((size,), dtype=torch.uint8, device=device)
def get_tmp_expand_size(size: int) -> int:
return _kernels.sgmv_cutlass_tmp_size(size)
def get_tmp_tensors(
nsegments: int, lora_rank: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
if use_cutlass_shrink(lora_rank) and has_sgmv():
tmp = get_tmp_tensor_for_size(nsegments, device)
return tmp, tmp
else:
tmp_shrink = get_tmp_tensor(device)
tmp_expand = get_tmp_tensor_for_size_no_kernels(nsegments, device)
return tmp_shrink, tmp_expand
def lora_a_sgmv_cutlass(
x: torch.Tensor,
tmp: torch.Tensor,
wa_ptr: torch.Tensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
lora_rank: int,
) -> torch.Tensor:
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
_kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
else:
_kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
return v
def lora_b_sgmv_cutlass(
y: torch.Tensor,
v: torch.Tensor,
tmp: torch.Tensor,
wb_ptr: torch.Tensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
):
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
v: Shape: `[B, R]`. Temporary vector.
x: Shape: `[B, H1]`. Input vectors.
wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices.
wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
"""
def add_lora_a_bgmv(
v: torch.Tensor,
x: torch.Tensor,
wa_T_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
):
_kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)
def add_lora_b_bgmv(
y: torch.Tensor,
v: torch.Tensor,
wb_T_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
):
_kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
def segmented_matmul(
y: torch.Tensor,
x: torch.Tensor,
w: List[torch.Tensor],
b: List[torch.Tensor],
s_start: torch.IntTensor,
s_end: torch.IntTensor,
):
for i in range(len(w)):
if s_end[i] - s_start[i] <= 0:
continue
xi = x[s_start[i] : s_end[i]]
wi = w[i]
bi = b[i]
y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)
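A toy invocation of this loop fallback, used when the punica kernels are unavailable (shapes are made up):

# Rows 0..1 of y are produced with w[0]/b[0], rows 2..3 with w[1]/b[1].
import torch

x = torch.randn(4, 8)
y = torch.zeros(4, 16)
w = [torch.randn(16, 8), torch.randn(16, 8)]
b = [torch.zeros(16), torch.zeros(16)]
segmented_matmul(y, x, w, b, torch.tensor([0, 2]), torch.tensor([2, 4]))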