Unverified Commit 04e1af94 authored by drbh, committed by GitHub

Enable multiple LoRa adapters (#2010)



* feat: first draft load multiple lora

* feat: load weights within layer and refactor lora pass

* fix: refactor and reduce lora math

* feat: baseline impl single request multi lora support

* feat: prefer lorax implementation and port loading logic

* fix: prefer adapter_data and refactors

* feat: prefer lorax's custom punica kernels and add mlp loras

* fix: adjust batch for bgmv

* fix: adjust adapter_segments logic when in batch

* fix: refactor and move changes to v3 proto

* fix: pass model_id for all flash causal lms

* fix: pass model_id for all causal and seq2seq lms

* fix: add model_id to model test

* feat: add lora support to mistral and refactors

* feat: prefer model id in request

* fix: include rust code for adapter id

* feat: bump launcher and add new lora docs

* feat: support base model generation and refactors

* fix: rename doc to retry ci build

* feat: support if vlm models

* fix: add adapter_data param and avoid missing layers

* fix: add adapter_data param to phi and neox

* fix: update all models forwards to include adapter_data

* fix: add model_id to IdeficsCausalLM

* Update lora.md

Fixed a typo

* Update lora.md

Fixing spam image

* fix: add lora kernel to dockerfile, support running without kernels and refactors

* fix: avoid dockerfile conflict

* fix: refactors and adjust flash llama lora logic

* fix: skip llama test due to CI issue (temp)

* fix: skip llama test CI (temp) 2

* fix: revert skips and prefer updated ci token for tests

* fix: refactors and helpful comments

* fix: add noop in TensorParallelAdapterRowLinear too

* fix: refactor and move shard_lora_weights logic

* fix: exit early if no adapter_data

---------
Co-authored-by: Derek <datavistics@gmail.com>
parent a2a97b05
@@ -145,6 +145,13 @@ COPY server/marlin/ .
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Lorax Punica kernels
FROM kernel-builder as lorax-punica-builder
WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile
# Build specific version of punica kernels
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
@@ -215,6 +222,7 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from marlin kernels builder
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
@@ -266,4 +274,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"] # CMD ["--json-output"]
@@ -157,6 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
adapter_id: None,
})
.collect();
...
@@ -60,6 +60,9 @@
- local: conceptual/speculation
title: Speculation (Medusa, ngram)
- local: conceptual/guidance
title: How Guidance Works (via outlines)
- local: conceptual/lora
title: LoRA (Low-Rank Adaptation)
title: Conceptual Guides
@@ -416,6 +416,14 @@ Options:
[env: MAX_CLIENT_BATCH_SIZE=]
[default: 4]
```
## LORA_ADAPTERS
```shell
--lora-adapters <LORA_ADAPTERS>
Lora Adapters: a comma-separated list of adapter ids, e.g. `repo/adapter1,repo/adapter2`, to load during startup. Loaded adapters are available to callers via the `adapter_id` field in a request
[env: LORA_ADAPTERS=]
```
## HELP
```shell
...
# LoRA (Low-Rank Adaptation)
## What is LoRA?
LoRA is a technique that allows for efficient fine-tuning of a model while updating only a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task.

LoRA works by adding a small number of additional weights to the model, which are used to adapt it to the new dataset or task. These additional weights are learned during fine-tuning, while the rest of the model's weights are kept fixed.
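Conceptually, a LoRA adapter replaces a layer's frozen weight matrix `W` with `W + (alpha / r) * B @ A`, where `A` and `B` are small low-rank matrices. Below is a minimal PyTorch sketch of that idea; the shapes and hyperparameters are illustrative, not TGI internals:

```python
import torch

# Minimal sketch of a LoRA-adapted linear layer (illustrative shapes, not TGI internals).
hidden = 4096
r, alpha = 16, 32

W = torch.randn(hidden, hidden)          # frozen base weight [out_features, in_features]
lora_A = torch.randn(r, hidden) * 0.01   # trainable low-rank factor [r, in_features]
lora_B = torch.zeros(hidden, r)          # trainable low-rank factor [out_features, r], zero-init

def lora_linear(x: torch.Tensor) -> torch.Tensor:
    # y = x @ W.T + (alpha / r) * x @ A.T @ B.T; only A and B are updated during fine-tuning
    return x @ W.T + (alpha / r) * ((x @ lora_A.T) @ lora_B.T)

y = lora_linear(torch.randn(2, hidden))  # -> [2, hidden]
```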
## How is it used?
LoRA can be used in many ways, and the community is always finding new ways to use it. At its core, LoRA fine-tunes a large language model with only a small number of trainable parameters, which makes it practical for a wide range of applications, such as:
- fine-tuning a language model on a small dataset
- fine-tuning a language model on a domain-specific dataset
- fine-tuning a language model on a dataset with limited labels
## Optimizing Inference with LoRA
LoRA adapters can be applied during inference by multiplying the adapter weights with the model weights at each adapted layer. This extra work can be computationally expensive, but thanks to the excellent work by [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels and frameworks have been developed to make this process more efficient. TGI leverages these optimizations in order to provide fast and efficient inference with multiple LoRA models.
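To see why specialized kernels matter, here is a deliberately naive sketch (not TGI's implementation) of applying different adapters to different requests within a single batch; TGI instead dispatches this work to the SGMV/BGMV kernels:

```python
import torch

# Naive illustration of multi-adapter batching: each request in the batch may use a
# different adapter (or none). TGI does NOT run this Python loop; it dispatches to the
# punica/lorax SGMV (prefill) and BGMV (decode) kernels instead.
hidden = 64
adapters = {
    0: (torch.randn(8, hidden) * 0.01, torch.zeros(hidden, 8)),    # (lora_A, lora_B), rank 8
    1: (torch.randn(16, hidden) * 0.01, torch.zeros(hidden, 16)),  # rank 16
}

def apply_adapters(x, base_out, adapter_indices):
    # x, base_out: [batch, hidden]; adapter_indices[i] picks the adapter for request i (-1 = base only)
    out = base_out.clone()
    for i, idx in enumerate(adapter_indices):
        if idx < 0:
            continue
        lora_a, lora_b = adapters[idx]
        out[i] += (x[i] @ lora_a.T) @ lora_b.T  # scaling factor omitted for brevity
    return out

x = torch.randn(3, hidden)
out = apply_adapters(x, x @ torch.randn(hidden, hidden).T, [0, -1, 1])
```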
## Serving multiple LoRA adapters with TGI
Once a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned.
In practice, it's often useful to have multiple LoRA models, each fine-tuned on a different dataset or for a different task. This allows you to use the model that is best suited to a particular task or dataset.
Text Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library.
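For reference, here is a hedged sketch of how such a `peft`-compatible adapter is typically produced; the model id, target modules, and hyperparameters below are placeholders:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Hedged sketch of producing a peft-compatible adapter; model id, target modules,
# and hyperparameters are illustrative.
base = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, config)
model.print_trainable_parameters()  # only the LoRA weights are trainable

# After fine-tuning, publish the adapter so TGI can load it by id:
# model.push_to_hub("my-org/my-lora-adapter")
```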
### Specifying LoRA models
To use LoRA in TGI, specify the list of LoRA models to load at startup via the `LORA_ADAPTERS` environment variable. For example:
```bash
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```
In the server logs, you will see the following messages:
```txt
Loading adapter weights into model: predibase/customer_support
Loading adapter weights into model: predibase/dbpedia
```
## Generate text
You can then use these adapters in generation requests by specifying the `adapter_id` parameter in the request payload. For example:
```bash
curl 127.0.0.1:3000/generate \
-X POST \
-H 'Content-Type: application/json' \
-d '{
"inputs": "Hello who are you?",
"parameters": {
"max_new_tokens": 40,
"adapter_id": "predibase/customer_support"
}
}'
```
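The same request can also be sent from Python, for example with the `requests` library (a sketch; adjust the host and port to your deployment):

```python
import requests

# Call the local TGI /generate endpoint with an adapter_id, mirroring the curl request above.
response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": "Hello who are you?",
        "parameters": {
            "max_new_tokens": 40,
            "adapter_id": "predibase/customer_support",
        },
    },
)
print(response.json()["generated_text"])
```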
> **Note:** The LoRA feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally, documentation and an improved client library will be published soon.
An updated tutorial with detailed examples will be published soon. Stay tuned!
@@ -452,6 +452,11 @@ struct Args {
/// Control the maximum number of inputs that a client can send in a single request
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
/// Lora Adapters: a comma-separated list of adapter ids, e.g. `repo/adapter1,repo/adapter2`, to load
/// during startup. Loaded adapters are available to callers via the `adapter_id` field in a request.
#[clap(long, env)]
lora_adapters: Option<String>,
}
#[derive(Debug)]
@@ -485,6 +490,7 @@ fn shard_manager(
max_total_tokens: usize,
max_batch_size: Option<usize>,
max_input_tokens: usize,
lora_adapters: Option<String>,
otlp_endpoint: Option<String>,
otlp_service_name: String,
log_level: LevelFilter,
@@ -620,6 +626,11 @@ fn shard_manager(
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
}
// Lora Adapters
if let Some(lora_adapters) = lora_adapters {
envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
}
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -1060,6 +1071,7 @@ fn spawn_shards(
let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor;
let max_batch_size = args.max_batch_size;
let lora_adapters = args.lora_adapters.clone();
thread::spawn(move || {
shard_manager(
model_id,
@@ -1085,6 +1097,7 @@ fn spawn_shards(
max_total_tokens,
max_batch_size,
max_input_tokens,
lora_adapters,
otlp_endpoint,
otlp_service_name,
max_log_level,
...
@@ -134,6 +134,8 @@ message Request {
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 10;
/// LORA adapter index
optional string adapter_id = 11;
}
message Batch {
...
@@ -177,6 +177,7 @@ impl Client {
}),
prefill_logprobs: true,
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += max_input_length;
...
@@ -244,6 +244,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
adapter_id: None,
};
let batch = Batch {
id: u64::MAX,
...
@@ -429,6 +429,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
adapter_id: None,
},
response_tx,
span: info_span!("entry"),
...
@@ -351,6 +351,7 @@ impl State {
top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
adapter_id: entry.request.adapter_id.clone(),
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@@ -491,6 +492,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
adapter_id: None,
},
response_tx,
span: info_span!("entry"),
...
@@ -302,6 +302,11 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub grammar: Option<GrammarType>,
/// Lora adapter id
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub adapter_id: Option<String>,
}
fn default_max_new_tokens() -> Option<u32> {
@@ -328,6 +333,7 @@ fn default_parameters() -> GenerateParameters {
seed: None,
top_n_tokens: None,
grammar: None,
adapter_id: None,
}
}
...
@@ -673,6 +673,7 @@ async fn completions(
seed,
top_n_tokens: None,
grammar: None,
..Default::default()
},
})
.collect();
@@ -1115,6 +1116,7 @@ async fn chat_completions(
seed,
top_n_tokens: req.top_logprobs,
grammar,
..Default::default()
},
};
...
@@ -202,6 +202,7 @@ impl Validation {
decoder_input_details,
top_n_tokens,
grammar,
adapter_id,
..
} = request.parameters;
@@ -383,6 +384,7 @@ impl Validation {
parameters,
stopping_parameters,
top_n_tokens,
adapter_id,
})
}
@@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest {
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,
pub top_n_tokens: u32,
pub adapter_id: Option<String>,
}
#[derive(Error, Debug)]
...
@@ -4,6 +4,7 @@ include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica
unit-tests:
pytest -s -vv -m "not private" tests
...
lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
build-lorax-punica:
if [ ! -d 'lorax-punica' ]; then \
git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \
fi
cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
cd lorax-punica && git submodule update --init --recursive
cd lorax-punica/server/punica_kernels && python setup.py build
install-lorax-punica: build-lorax-punica
cd lorax-punica/server/punica_kernels && python setup.py install
@@ -17,7 +17,12 @@ def get_test_model():
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
model = TestModel(
"test_model_id",
torch.nn.Linear(1, 1),
tokenizer,
False,
torch.float32,
torch.device("cpu"),
)
return model
...
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/__init__.py
# License: Apache License Version 2.0, January 2004
from text_generation_server.adapters.weights import (
AdapterBatchData,
AdapterBatchMetadata,
)
__all__ = [
"AdapterBatchData",
"AdapterBatchMetadata",
]
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/config.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
import torch
from text_generation_server.adapters.weights import AdapterWeights
if TYPE_CHECKING:
from text_generation_server.models.model import Model
@dataclass
class ModuleMap:
module_name: str
module_weights: Dict[str, Tuple[torch.Tensor, str]]
@dataclass
class AdapterConfig(ABC):
base_model_name_or_path: str
@abstractmethod
def map_weights_for_model(
self,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
pass
@abstractmethod
def load_batched_adapter_weights(
self,
model: "Model",
module_map: ModuleMap,
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
pass
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/lora.py
# License: Apache License Version 2.0, January 2004
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.adapters.weights import (
AdapterBatchMetadata,
AdapterWeights,
BatchAdapterWeights,
)
from text_generation_server.utils.sgmv import (
BGMV_MAX_RANK,
MAX_RANK_CUSTOM,
get_tmp_tensors,
orient_for_rank,
pad_rank,
use_cutlass_shrink,
)
if TYPE_CHECKING:
from text_generation_server.models.model import Model
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size
start = offset + rank * block_size
stop = offset + (rank + 1) * block_size
return start, stop
def shard_on_dim(
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
):
world_size = process_group.size()
rank = process_group.rank()
size = t.shape[dim]
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
if dim == 0:
tensor = t[start:stop]
elif dim == 1:
tensor = t[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
return tensor
def shard_lora_weights(
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
split_dim: int,
process_group: ProcessGroup,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# [hidden_size, r]
weights_a = [
shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
]
# [r, hidden_size]
weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
return weights_a, weights_b
@dataclass
class LoraConfig(AdapterConfig):
r: int
target_modules: Optional[Union[List[str], str]]
fan_in_fan_out: bool
lora_alpha: int
use_rslora: bool
def map_weights_for_model(
self,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
adapter_weight_names = set()
module_map = {}
for weight_name in weight_names:
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
continue
module_map[weight_name] = {
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
}
adapter_weight_names.add(lora_a_name)
adapter_weight_names.add(lora_b_name)
return module_map, adapter_weight_names
def load_batched_adapter_weights(
self,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
return LoraWeights.load(
self,
model,
module_map,
layer_type,
unused_weight_names,
)
@classmethod
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
return cls(
base_model_name_or_path=hf_config.base_model_name_or_path,
r=hf_config.r,
target_modules=hf_config.target_modules,
fan_in_fan_out=hf_config.fan_in_fan_out,
lora_alpha=hf_config.lora_alpha,
use_rslora=(
hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
),
)
class LoraWeights(AdapterWeights):
"""LoRA weights for a single adapter merged across all layers."""
def __init__(
self,
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
adapter_config: LoraConfig,
):
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_b) > 0 else 1
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
self._is_transposed = False
# [num_layers, hidden_size, r]
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
self._weights_a = torch.stack(weights_a)
# [num_layers, r, hidden_size]
self._weights_b = torch.stack(weights_b)
self.adapter_config = adapter_config
@property
def weights_a(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_a
@property
def weights_b(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_b
@property
def weights_a_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_a
@property
def weights_b_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_b
def _transpose_weights(self):
if self._use_cutlass_shrink:
# If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
self._weights_a = self._weights_a.transpose(1, 2).contiguous()
self._weights_b = self._weights_b.transpose(1, 2).contiguous()
self._is_transposed = not self._is_transposed
@classmethod
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
return [BatchLoraWeights]
@classmethod
def load(
cls,
config: LoraConfig,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
) -> Optional[AdapterWeights]:
nlayers = model.get_num_layers_for_type(layer_type)
lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers
for layer_id in range(nlayers):
key = (layer_id, layer_type)
weight_name, layer = model.target_to_layer[key]
base_weight = layer.base_layer.linear.weight
base_device = base_weight.device
if weight_name not in module_map:
# There is no LoRA weight for this layer type in the adapter
return None
lora_a, lora_a_name = module_map[weight_name]["lora_A"]
lora_a = lora_a.to(base_device, model.dtype)
lora_b, lora_b_name = module_map[weight_name]["lora_B"]
lora_b = lora_b.to(base_device, model.dtype)
scale = get_scaling_factor(
config.lora_alpha,
config.r,
uses_rslora=config.use_rslora,
)
unused_weight_names.discard(lora_a_name)
unused_weight_names.discard(lora_b_name)
# Merge scaling factor into lora_b due to associativity of matrix multiplication:
# (A * B) * C = A * (B * C)
lora_a_list[layer_id] = lora_a.transpose(0, 1)
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
# pad lora ranks to be compatible with sgmv
lora_a_list = [
pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
]
lora_b_list = [
pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
]
if lora_a_list:
# update rank if it was padded
padded_rank = lora_a_list[0].size(1)
config.r = padded_rank
return LoraWeights(
*shard_lora_weights(
weights_a=lora_a_list,
weights_b=lora_b_list,
split_dim=0 if model.is_row_parallel(layer_type) else 1,
process_group=model.process_group,
),
config,
)
@dataclass
class RankSegments:
rank: int
lora_a_ptr: torch.Tensor
lora_b_ptr: torch.Tensor
# prefill (sgmv)
tmp_shrink: torch.Tensor
tmp_expand: torch.Tensor
segment_starts: torch.Tensor
segment_ends: torch.Tensor
# decode (bgmv)
indices: torch.Tensor
@dataclass
class BatchLoraWeights(BatchAdapterWeights):
lora_a: Dict[int, torch.Tensor]
lora_b: Dict[int, torch.Tensor]
adapter_index_configs: Dict[int, LoraConfig]
rank_data: Dict[int, RankSegments]
use_sgmv: bool
def has_adapter(self, adapter_index: int) -> bool:
return adapter_index in self.adapter_index_configs
def can_vectorize(self, pg: ProcessGroup) -> bool:
return all(
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
for rank_data in self.rank_data.values()
)
@classmethod
def key(cls) -> str:
return "lora"
@classmethod
def load(
self,
adapter_weights: Dict[int, AdapterWeights],
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Optional["BatchLoraWeights"]:
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
adapter_weights = {
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
}
if not adapter_weights:
return None
first_weights = next(iter(adapter_weights.values()))
device = first_weights.weights_a.device
segment_indices = meta.segment_indices
lora_a = {
idx: adapter_weights[idx].weights_a
for idx in segment_indices
if idx in adapter_weights
}
lora_b = {
idx: adapter_weights[idx].weights_b
for idx in segment_indices
if idx in adapter_weights
}
max_rank = max(
(
adapter_weights[idx].lora_a_r
for idx in segment_indices
if idx in adapter_weights
),
default=0,
)
if prefill or max_rank > BGMV_MAX_RANK:
use_sgmv = True
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
else:
use_sgmv = False
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a_t.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b_t.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
adapter_index_configs = {
idx: adapter_weights[idx].adapter_config
for idx in segment_indices
if idx in adapter_weights
}
adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in adapter_weights:
continue
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
if prefill_head_indices is not None:
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
for head_index in prefill_head_indices:
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
if head_index < meta.adapter_segments[j]:
prefill_head_segment_ends[-1] += 1
else:
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
j += 1
rank_data = {}
for rank, indices in rank_indices.items():
tmp_shrink = None
tmp_expand = None
segment_starts = None
segment_ends = None
batch_indices = None
if use_sgmv:
lora_a_ptr_indices = lora_a_ptr[indices]
tmp_shrink, tmp_expand = get_tmp_tensors(
lora_a_ptr_indices.size(0), rank, device
)
segment_starts = meta.adapter_segments[indices]
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
if prefill_head_indices is not None:
for i, segment_index in enumerate(indices):
segment_starts[i] = prefill_head_segment_starts[segment_index]
segment_ends[i] = prefill_head_segment_ends[segment_index]
else:
rank_indices = set(indices)
batch_indices = [
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
]
batch_indices = [
idx if idx in rank_indices else -1 for idx in batch_indices
]
batch_indices = torch.tensor(
batch_indices, dtype=torch.int64, device=device
)
rank_data[rank] = RankSegments(
rank=rank,
tmp_shrink=tmp_shrink,
tmp_expand=tmp_expand,
lora_a_ptr=lora_a_ptr[indices],
lora_b_ptr=lora_b_ptr[indices],
segment_starts=segment_starts,
segment_ends=segment_ends,
indices=batch_indices,
)
return BatchLoraWeights(
lora_a=lora_a,
lora_b=lora_b,
adapter_index_configs=adapter_index_configs,
rank_data=rank_data,
use_sgmv=use_sgmv,
)
def get_scaling_factor(
lora_alpha: int,
r: int,
uses_rslora: bool = False,
) -> float:
"""Computes the scaling factor for the lora weights."""
if uses_rslora:
return lora_alpha / (r**0.5)
return lora_alpha / r
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
if hasattr(v, "lora_weights"):
return v.lora_weights
return v