Unverified Commit 04e1af94 authored by drbh, committed by GitHub

Enable multiple LoRa adapters (#2010)



* feat: first draft load multiple lora

* feat: load weights within layer and refactor lora pass

* fix: refactor and reduce lora math

* feat: baseline impl single request multi lora support

* feat: prefer lorax implementation and port loading logic

* fix: prefer adapter_data and refactors

* feat: prefer lorax's custom punica kernels and add mlp loras

* fix: adjust batch for bgmv

* fix: adjust adapter_segments logic when in batch

* fix: refactor and move changes to v3 proto

* fix: pass model_id for all flash causal lms

* fix: pass model_id for all causal and seq2seq lms

* fix: add model_id to model test

* feat: add lora support to mistral and refactors

* feat: prefer model id in request

* fix: include rust code for adapter id

* feat: bump launcher and add new lora docs

* feat: support base model generation and refactors

* fix: rename doc to retry ci build

* feat: support for vlm models

* fix: add adapter_data param and avoid missing layers

* fix: add adapter_data param to phi and neox

* fix: update all models forwards to include adapter_data

* fix: add model_id to IdeficsCausalLM

* Update lora.md

Fixed a typo

* Update lora.md

Fixing spam image

* fix: add lora kernel to dockerfile, support running without kernels and refactors

* fix: avoid dockerfile conflict

* fix: refactors and adjust flash llama lora logic

* fix: skip llama test due to CI issue (temp)

* fix: skip llama test CI (temp) 2

* fix: revert skips and prefer updated ci token for tests

* fix: refactors and helpful comments

* fix: add noop in TensorParallelAdapterRowLinear too

* fix: refactor and move shard_lora_weights logic

* fix: exit early if no adapter_data

---------
Co-authored-by: Derek <datavistics@gmail.com>
parent a2a97b05
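
For context, a minimal usage sketch of what this change enables: serve one base model with several LoRA adapters preloaded and route an individual request to one of them. The adapter names, the LORA_ADAPTERS launcher variable, the port, and the adapter_id request field below are assumptions rather than values taken from this diff; the lora.md docs added by this PR are the authoritative reference.

# Launcher side (illustrative): preload two adapters next to the base model, e.g.
#   LORA_ADAPTERS=org/adapter-a,org/adapter-b text-generation-launcher --model-id org/base-model
import requests  # third-party HTTP client, used here purely for illustration

response = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {
            "max_new_tokens": 32,
            # assumed request field added by this PR; omitting it (or sending an
            # unknown id) falls back to the base model, i.e. adapter index 0
            "adapter_id": "org/adapter-a",
        },
    },
    timeout=60,
)
print(response.json())
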
...@@ -525,6 +525,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
...
...@@ -741,6 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
...
...@@ -178,6 +178,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
inputs_embeds = self.language_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
...
...@@ -13,6 +13,7 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Iterable, Optional, Tuple, List, Type, Dict
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
...@@ -31,6 +32,7 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
import text_generation_server.models.globals as tgi_globals
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
from text_generation_server.utils.import_utils import (
empty_cache,
...@@ -114,6 +116,9 @@ class FlashCausalLMBatch(Batch):
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Adapter metadata for each request
adapter_meta: AdapterBatchMetadata
# Number of blocks in this batch
num_blocks: int
# Maximum number of blocks
...@@ -174,6 +179,9 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
top_n_tokens = []
adapter_indices_list = []
adapter_set = set()
# Cumulative length
cumulative_length = 0
cumulative_max_length = 0
...@@ -225,6 +233,10 @@ class FlashCausalLMBatch(Batch):
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(torch.full((input_length,), adapter_index))
adapter_set.add(adapter_index)
# Paged attention
# Remove one as the first token des not have a past
speculative_length = get_speculate()
...@@ -296,6 +308,10 @@ class FlashCausalLMBatch(Batch):
max_length, input_length + max_new_tokens + speculative_length
)
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer
)
...@@ -339,6 +355,11 @@ class FlashCausalLMBatch(Batch):
input_lengths, dtype=torch.int32, device=device
)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
...@@ -393,6 +414,12 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
num_blocks=num_blocks,
max_blocks=max_blocks,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
speculative_ids=None,
)
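
To make the adapter bookkeeping above concrete, here is an illustrative reimplementation of what find_segments is expected to return, inferred from how its results are used in this hunk (the real helper lives in text_generation_server.utils.segments and may differ in detail): contiguous runs of the same adapter index collapse into segments described by cumulative token boundaries plus one adapter index per segment.

from typing import List, Tuple

import torch

def find_segments_sketch(adapter_indices: torch.Tensor) -> Tuple[List[int], List[int]]:
    # Collapse contiguous runs of the same adapter index into segments.
    # Returns cumulative token boundaries and one adapter index per segment.
    indices = adapter_indices.tolist()
    segments, segment_indices = [0], []
    for i in range(1, len(indices)):
        if indices[i] != indices[i - 1]:
            segments.append(i)
            segment_indices.append(indices[i - 1])
    if indices:
        segments.append(len(indices))
        segment_indices.append(indices[-1])
    return segments, segment_indices

# Two requests of lengths 3 and 2 using adapter 1 and the base model (index 0):
print(find_segments_sketch(torch.tensor([1, 1, 1, 0, 0])))  # ([0, 3, 5], [1, 0])
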
...@@ -443,6 +470,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
top_n_tokens = []
adapter_set = set()
num_blocks = 0
max_blocks = 0
...@@ -471,6 +499,11 @@ class FlashCausalLMBatch(Batch):
top_n_tokens.append(self.top_n_tokens[idx])
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(
self.requests[idx].adapter_id, 0
)
adapter_set.add(adapter_index)
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
...@@ -498,6 +531,7 @@ class FlashCausalLMBatch(Batch):
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
...@@ -513,6 +547,11 @@ class FlashCausalLMBatch(Batch):
# Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
return type(self)(
batch_id=self.batch_id,
requests=requests,
...@@ -543,6 +582,12 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
)
@classmethod
...@@ -596,6 +641,14 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
total_indices_size = sum(
b.adapter_meta.adapter_indices.shape[0] for b in batches
)
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
total_indices_size
)
adapter_set = set()
adapter_segment_builder = SegmentConcatBuilder()
start_slots = []
block_tables = []
...@@ -613,6 +666,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_batch_size = 0
cumulative_slots = 0
cumulative_adapter_indices_size = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
...@@ -637,6 +691,21 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots
# Copy over adapter indices
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
+ batch.adapter_meta.adapter_indices.shape[0]
)
adapter_indices[adapter_start_index:adapter_end_index] = (
batch.adapter_meta.adapter_indices
)
cumulative_adapter_indices_size = adapter_end_index
adapter_set.update(batch.adapter_meta.adapter_set)
adapter_segment_builder.concat(
batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
...@@ -680,6 +749,8 @@ class FlashCausalLMBatch(Batch):
else None
)
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
return cls(
batch_id=batches[0].batch_id,
requests=requests,
...@@ -710,6 +781,12 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
)
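
The concatenation path above leans on SegmentConcatBuilder to stitch per-batch segments into one set of boundaries. The sketch below is an assumption-laden approximation of that behaviour (cumulative boundaries as produced by find_segments, adjacent segments that share an adapter merged), not the actual text_generation_server.utils.segments implementation.

from typing import List, Tuple

import torch

class SegmentConcatBuilderSketch:
    # Illustrative only: accumulate (segments, segment_indices) pairs from several
    # batches into one set of cumulative boundaries covering the combined batch.
    def __init__(self):
        self.segments: List[torch.Tensor] = []
        self.segment_indices: List[int] = []

    def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
        if self.segments:
            # Shift this batch's boundaries by the current total token count and
            # drop its leading 0, which duplicates the previous end boundary.
            adapter_segments = adapter_segments[1:] + self.segments[-1][-1]
            if self.segment_indices and self.segment_indices[-1] == segment_indices[0]:
                # Adjacent segments that use the same adapter are merged.
                segment_indices = segment_indices[1:]
                self.segments[-1] = self.segments[-1][:-1]
        self.segment_indices.extend(segment_indices)
        self.segments.append(adapter_segments)

    def build(self) -> Tuple[torch.Tensor, List[int]]:
        return torch.cat(self.segments), self.segment_indices

builder = SegmentConcatBuilderSketch()
builder.concat(torch.tensor([0, 3, 5]), [0, 1])  # batch A: 5 tokens, adapters 0 then 1
builder.concat(torch.tensor([0, 2]), [1])        # batch B: 2 tokens, adapter 1
print(builder.build())  # (tensor([0, 3, 7]), [0, 1]) -- trailing segments merged
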
def __len__(self):
...@@ -719,6 +796,7 @@ class FlashCausalLMBatch(Batch):
class FlashCausalLM(Model):
def __init__(
self,
model_id: str,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
num_layers: int,
...@@ -738,6 +816,7 @@ class FlashCausalLM(Model):
self.kv_cache = []
super(FlashCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
...@@ -895,12 +974,13 @@ class FlashCausalLM(Model):
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
batch_num_blocks = batch.num_blocks if batch is not None else 0
num_blocks = (
# Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size)
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ batch_num_blocks
)
del batch
...@@ -1001,7 +1081,7 @@ class FlashCausalLM(Model):
)
def forward(
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
...@@ -1080,6 +1160,7 @@ class FlashCausalLM(Model):
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
adapter_data=adapter_data,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
...@@ -1116,7 +1197,34 @@ class FlashCausalLM(Model):
prefill = batch.cu_seqlen_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta
if batch.speculative_ids is not None:
B, speculative_length = batch.speculative_ids.shape
new_length = speculative_length + 1
adapter_indices = (
adapter_meta.adapter_indices.unsqueeze(-1)
.expand(B, new_length)
.reshape(-1)
)
adapter_segments = adapter_meta.adapter_segments * new_length
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_meta.adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_meta.segment_indices,
)
# Assign pointers to adapter weights
# TODO(travis): don't update this if indices haven't changed
adapter_data = AdapterBatchData.from_meta(
adapter_meta,
self.layer_to_adapter_weights,
prefill,
batch.prefill_head_indices,
)
out, speculative_logits = self.forward(batch, adapter_data)
if prefill:
next_token_logits = (
...@@ -1128,8 +1236,13 @@ class FlashCausalLM(Model):
if prefill_logprobs
else speculative_logits
)
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
len(batch)
)
else:
next_token_logits = out
next_adapter_indices = batch.adapter_meta.adapter_indices
speculate = get_speculate()
(
...@@ -1195,6 +1308,12 @@ class FlashCausalLM(Model):
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
# Initialize adapter indices
# In decode, we only have one token per row in the batch, so grab last index
next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
end_index - 1
]
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if prefill_logprobs:
...@@ -1220,6 +1339,16 @@ class FlashCausalLM(Model):
batch.position_ids = next_position_ids + accepted_ids
batch.input_lengths_tensor += accepted_ids
batch.slot_indices += accepted_ids
batch.adapter_meta.adapter_indices = next_adapter_indices
if prefill:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
if prefill and prefill_logprobs:
# Get prefill logprobs
...
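
A small worked example of the decode-time bookkeeping above, with made-up values: during prefill each request contributes one adapter index per input token; after prefill only the last index per request is kept, segments are recomputed over one-token rows, and with speculation every row is expanded to speculative_length + 1 positions.

import torch

# Prefill: requests of lengths 3 and 2 using adapter 2 and the base model (0).
prefill_adapter_indices = torch.tensor([2, 2, 2, 0, 0])

# After prefill only the last index of each request survives (one row per request),
# and find_segments over these one-token rows yields boundaries [0, 1, 2] with indices [2, 0].
decode_adapter_indices = torch.tensor([2, 0])

# With speculation, speculative_length = 2 gives new_length = 3: every row is
# expanded to new_length positions and the cumulative boundaries scale to match.
new_length = 3
expanded = decode_adapter_indices.unsqueeze(-1).expand(2, new_length).reshape(-1)
print(expanded)                              # tensor([2, 2, 2, 0, 0, 0])
print(torch.tensor([0, 1, 2]) * new_length)  # tensor([0, 3, 6])
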
...@@ -62,6 +62,7 @@ class FlashCohere(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashCohere, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
...
...@@ -87,6 +87,7 @@ class FlashDbrx(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashDbrx, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
...
...@@ -62,6 +62,7 @@ class FlashGemma(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashGemma, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
...
...@@ -69,6 +69,7 @@ class FlashGPT2(FlashCausalLM):
model = FlashGPT2ForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashGPT2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
...
import os
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
...@@ -13,12 +14,24 @@ from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashLlama(FlashCausalLM):
def __init__(
...@@ -29,6 +42,7 @@ class FlashLlama(FlashCausalLM):
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
...@@ -78,6 +92,7 @@ class FlashLlama(FlashCausalLM):
model = FlashLlamaForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
...@@ -88,3 +103,69 @@ class FlashLlama(FlashCausalLM):
rank=rank,
world_size=world_size,
)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
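
Note that in the mapping above q_proj, k_proj and v_proj all resolve to the fused query_key_value module and gate_proj/up_proj both resolve to gate_up_proj, so per-target LoRA weights have to land in the right slice of the fused layer when they are applied. A tiny self-contained sketch of that shape (the dict values here are toy stand-ins, not real modules):

from collections import defaultdict

# Toy stand-in for the dict returned by adapter_target_to_layer(): several LoRA
# targets can point at the same fused module object (represented by strings here).
layer_weights = {
    (0, "q_proj"): ("model.layers.0.self_attn.q_proj", "qkv_module_0"),
    (0, "k_proj"): ("model.layers.0.self_attn.k_proj", "qkv_module_0"),
    (0, "v_proj"): ("model.layers.0.self_attn.v_proj", "qkv_module_0"),
    (0, "o_proj"): ("model.layers.0.self_attn.o_proj", "o_proj_module_0"),
}

targets_per_module = defaultdict(list)
for (layer_idx, target), (_prefix, module) in layer_weights.items():
    targets_per_module[(layer_idx, module)].append(target)

print(dict(targets_per_module))
# {(0, 'qkv_module_0'): ['q_proj', 'k_proj', 'v_proj'], (0, 'o_proj_module_0'): ['o_proj']}
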
...@@ -3,7 +3,7 @@ import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import set_sliding_window
...@@ -21,6 +21,18 @@ from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class BaseFlashMistral(FlashCausalLM):
def __init__(
self,
...@@ -83,6 +95,7 @@ class BaseFlashMistral(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
num_layers, num_kv_heads, head_size = self.get_layer_config(model)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=num_layers,
...@@ -102,6 +115,75 @@ class BaseFlashMistral(FlashCausalLM):
model.model.head_size,
)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
# TODO: this is a hack to avoid the gate_proj for
# FlashStarcoder2 that doesnt have these layers
if hasattr(layer.mlp, "gate_up_proj"):
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
class FlashMistral(BaseFlashMistral):
def __init__(
...
...@@ -69,6 +69,7 @@ class FlashNeoXSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashNeoXSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.gpt_neox.layers),
...
...@@ -90,6 +90,7 @@ class FlashPhi(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashPhi, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
...
...@@ -71,6 +71,7 @@ class FlashQwen2(BaseFlashMistral):
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
...
...@@ -78,6 +78,7 @@ class FlashRWSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashRWSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.transformer.h),
...
...@@ -80,6 +80,7 @@ class FlashSantacoderSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashSantacoderSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.transformer.h),
...
...@@ -70,6 +70,7 @@ class FlashStarcoder2(BaseFlashMistral):
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
...
...@@ -212,6 +212,7 @@ class GalacticaSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
...
import torch
import os
from loguru import logger
from typing import Dict
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
...@@ -32,3 +33,14 @@ MODEL_ID = None
def set_model_id(model_id: str):
global MODEL_ID
MODEL_ID = model_id
# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
global ADAPTER_TO_INDEX
ADAPTER_TO_INDEX: Dict[str, int] = None
def set_adapter_to_index(adapter_to_index: Dict[str, int]):
global ADAPTER_TO_INDEX
ADAPTER_TO_INDEX = adapter_to_index
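
The ADAPTER_TO_INDEX global above is what the batching code consults via tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0). A minimal sketch of how it is presumably populated and queried (the adapter names are made up; index 0 stands for the base model via the .get(..., 0) fallback):

import text_generation_server.models.globals as tgi_globals

# Presumably populated once at server startup: each loaded adapter id gets a
# dense index, with index 0 reserved for the base model.
tgi_globals.set_adapter_to_index({"org/adapter-a": 1, "org/adapter-b": 2})

def adapter_index_for(adapter_id: str) -> int:
    # Unknown or empty adapter ids fall back to the base model, mirroring
    # ADAPTER_TO_INDEX.get(r.adapter_id, 0) in the batch code above.
    return tgi_globals.ADAPTER_TO_INDEX.get(adapter_id, 0)

assert adapter_index_for("org/adapter-a") == 1
assert adapter_index_for("") == 0
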
...@@ -65,6 +65,7 @@ class GPTNeoxSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
...
...@@ -83,6 +83,7 @@ class IDEFICSSharded(IdeficsCausalLM):
torch.distributed.barrier(group=self.process_group)
super(IdeficsCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
...