Unverified Commit 04e1af94 authored by drbh, committed by GitHub

Enable multiple LoRA adapters (#2010)



* feat: first draft load multiple lora

* feat: load weights within layer and refactor lora pass

* fix: refactor and reduce lora math

* feat: baseline impl single request multi lora support

* feat: prefer lorax implementation and port loading logic

* fix: prefer adapter_data and refactors

* feat: prefer lorax's custom punica kernels and add mlp loras

* fix: adjust batch for bgmv

* fix: adjust adapter_segments logic when in batch

* fix: refactor and move changes to v3 proto

* fix: pass model_id for all flash causal lms

* fix: pass model_id for all causal and seq2seq lms

* fix: add model_id to model test

* feat: add lora support to mistral and refactors

* feat: prefer model id in request

* fix: include rust code for adapter id

* feat: bump launcher and add new lora docs

* feat: support base model generation and refactors

* fix: rename doc to retry ci build

* feat: support for vlm models

* fix: add adapter_data param and avoid missing layers

* fix: add adapter_data param to phi and neox

* fix: update all models forwards to include adapter_data

* fix: add model_id to IdeficsCausalLM

* Update lora.md

Fixed a typo

* Update lora.md

Fixing spam image

* fix: add lora kernel to dockerfile, support running without kernels and refactors

* fix: avoid dockerfile conflict

* fix: refactors and adjust flash llama lora logic

* fix: skip llama test due to CI issue (temp)

* fix: skip llama test CI (temp) 2

* fix: revert skips and prefer updated ci token for tests

* fix: refactors and helpful comments

* fix: add noop in TensorParallelAdapterRowLinear too

* fix: refactor and move shard_lora_weights logic

* fix: exit early if no adapter_data

---------
Co-authored-by: Derek <datavistics@gmail.com>
parent a2a97b05
......@@ -525,6 +525,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
......
......@@ -741,6 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
......
......@@ -178,6 +178,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
inputs_embeds = self.language_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
......
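Review note: the three signature changes above follow one pattern: every model forward gains an optional `adapter_data` keyword so the batching code can pass adapter metadata unconditionally, and models that do not use LoRA simply ignore it. A minimal sketch of that pattern (the class and names below are illustrative, not actual TGI modules):

```python
from typing import Optional
import torch

class ToyModelWithoutLora(torch.nn.Module):
    """Accepts adapter_data for interface compatibility but never reads it."""

    def __init__(self, hidden: int = 16, vocab: int = 32):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab)

    def forward(
        self,
        input_ids: torch.Tensor,
        adapter_data: Optional[object] = None,  # ignored, kept so every call site stays uniform
    ) -> torch.Tensor:
        return self.lm_head(self.embed(input_ids))

# The batching code can now pass adapter_data to every model without special-casing:
logits = ToyModelWithoutLora()(torch.tensor([[1, 2, 3]]), adapter_data=None)
print(logits.shape)  # torch.Size([1, 3, 32])
```

Keeping the argument optional means existing callers and models without adapter support keep working unchanged.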
......@@ -13,6 +13,7 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Iterable, Optional, Tuple, List, Type, Dict
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
......@@ -31,6 +32,7 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
import text_generation_server.models.globals as tgi_globals
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
from text_generation_server.utils.import_utils import (
empty_cache,
......@@ -114,6 +116,9 @@ class FlashCausalLMBatch(Batch):
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Adapter metadata for each request
adapter_meta: AdapterBatchMetadata
# Number of blocks in this batch
num_blocks: int
# Maximum number of blocks
......@@ -174,6 +179,9 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
top_n_tokens = []
adapter_indices_list = []
adapter_set = set()
# Cumulative length
cumulative_length = 0
cumulative_max_length = 0
......@@ -225,6 +233,10 @@ class FlashCausalLMBatch(Batch):
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(torch.full((input_length,), adapter_index))
adapter_set.add(adapter_index)
# Paged attention
# Remove one as the first token does not have a past
speculative_length = get_speculate()
......@@ -296,6 +308,10 @@ class FlashCausalLMBatch(Batch):
max_length, input_length + max_new_tokens + speculative_length
)
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer
)
......@@ -339,6 +355,11 @@ class FlashCausalLMBatch(Batch):
input_lengths, dtype=torch.int32, device=device
)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
......@@ -393,6 +414,12 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
num_blocks=num_blocks,
max_blocks=max_blocks,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
speculative_ids=None,
)
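For readers of the `from_pb` hunks above: each request contributes `input_length` copies of its adapter index, and `find_segments` (imported from `text_generation_server.utils.segments`) collapses that flat tensor into contiguous segments plus one adapter index per segment. A rough sketch of the assumed behaviour; the real helper may differ in details:

```python
from typing import List, Tuple
import torch

def find_segments_sketch(adapter_indices: torch.Tensor) -> Tuple[List[int], List[int]]:
    """Assumed semantics: return cumulative segment boundaries (starting at 0)
    and the adapter index owning each segment, i.e. a run-length view of the batch."""
    segments = [0]
    segment_indices = []
    prev = None
    for pos, idx in enumerate(adapter_indices.tolist()):
        if idx != prev:
            segment_indices.append(idx)
            if prev is not None:
                segments.append(pos)
            prev = idx
    segments.append(len(adapter_indices))
    return segments, segment_indices

# Two requests of length 3 on adapter 1, then one request of length 2 on adapter 0:
flat = torch.tensor([1, 1, 1, 1, 1, 1, 0, 0])
print(find_segments_sketch(flat))  # ([0, 6, 8], [1, 0])
```

The segment view is presumably what the punica-style kernels consume downstream, so the batch only needs run boundaries rather than per-token bookkeeping.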
......@@ -443,6 +470,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
top_n_tokens = []
adapter_set = set()
num_blocks = 0
max_blocks = 0
......@@ -471,6 +499,11 @@ class FlashCausalLMBatch(Batch):
top_n_tokens.append(self.top_n_tokens[idx])
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(
self.requests[idx].adapter_id, 0
)
adapter_set.add(adapter_index)
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
......@@ -498,6 +531,7 @@ class FlashCausalLMBatch(Batch):
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
......@@ -513,6 +547,11 @@ class FlashCausalLMBatch(Batch):
# Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
return type(self)(
batch_id=self.batch_id,
requests=requests,
......@@ -543,6 +582,12 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
)
@classmethod
......@@ -596,6 +641,14 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
total_indices_size = sum(
b.adapter_meta.adapter_indices.shape[0] for b in batches
)
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
total_indices_size
)
adapter_set = set()
adapter_segment_builder = SegmentConcatBuilder()
start_slots = []
block_tables = []
......@@ -613,6 +666,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_batch_size = 0
cumulative_slots = 0
cumulative_adapter_indices_size = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
......@@ -637,6 +691,21 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots
# Copy over adapter indices
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
+ batch.adapter_meta.adapter_indices.shape[0]
)
adapter_indices[adapter_start_index:adapter_end_index] = (
batch.adapter_meta.adapter_indices
)
cumulative_adapter_indices_size = adapter_end_index
adapter_set.update(batch.adapter_meta.adapter_set)
adapter_segment_builder.concat(
batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
......@@ -680,6 +749,8 @@ class FlashCausalLMBatch(Batch):
else None
)
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
return cls(
batch_id=batches[0].batch_id,
requests=requests,
......@@ -710,6 +781,12 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
)
def __len__(self):
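The concatenate hunks above copy each batch's adapter indices into one flat tensor and delegate segment stitching to `SegmentConcatBuilder`. The sketch below shows what that stitching has to do under the segment format assumed earlier (offset the incoming boundaries by the length already consumed and merge the seam when both sides use the same adapter); it is illustrative, not the real implementation:

```python
from typing import List, Tuple

class SegmentConcatSketch:
    """Toy stand-in for SegmentConcatBuilder: segments are cumulative boundaries
    starting at 0, indices hold one adapter id per segment."""

    def __init__(self) -> None:
        self.segments: List[int] = [0]
        self.indices: List[int] = []

    def concat(self, segments: List[int], segment_indices: List[int]) -> None:
        offset = self.segments[-1]
        new_boundaries = [b + offset for b in segments[1:]]  # skip the leading 0
        if self.indices and segment_indices and self.indices[-1] == segment_indices[0]:
            # Same adapter on both sides of the seam: drop the boundary between
            # the batches and reuse the existing index entry.
            self.segments.pop()
            segment_indices = segment_indices[1:]
        self.segments.extend(new_boundaries)
        self.indices.extend(segment_indices)

    def build(self) -> Tuple[List[int], List[int]]:
        return self.segments, self.indices

builder = SegmentConcatSketch()
builder.concat([0, 3, 5], [1, 0])  # batch A: 3 tokens on adapter 1, then 2 on adapter 0
builder.concat([0, 4], [0])        # batch B: 4 tokens on adapter 0
print(builder.build())             # ([0, 3, 9], [1, 0]) -> the seam merged
```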
......@@ -719,6 +796,7 @@ class FlashCausalLMBatch(Batch):
class FlashCausalLM(Model):
def __init__(
self,
model_id: str,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
num_layers: int,
......@@ -738,6 +816,7 @@ class FlashCausalLM(Model):
self.kv_cache = []
super(FlashCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
......@@ -895,12 +974,13 @@ class FlashCausalLM(Model):
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
batch_num_blocks = batch.num_blocks if batch is not None else 0
num_blocks = (
# Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size)
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ batch.num_blocks
+ batch_num_blocks
)
del batch
......@@ -1001,7 +1081,7 @@ class FlashCausalLM(Model):
)
def forward(
self, batch: FlashCausalLMBatch
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
......@@ -1080,6 +1160,7 @@ class FlashCausalLM(Model):
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
adapter_data=adapter_data,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
......@@ -1116,7 +1197,34 @@ class FlashCausalLM(Model):
prefill = batch.cu_seqlen_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
out, speculative_logits = self.forward(batch)
# Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta
if batch.speculative_ids is not None:
B, speculative_length = batch.speculative_ids.shape
new_length = speculative_length + 1
adapter_indices = (
adapter_meta.adapter_indices.unsqueeze(-1)
.expand(B, new_length)
.reshape(-1)
)
adapter_segments = adapter_meta.adapter_segments * new_length
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_meta.adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_meta.segment_indices,
)
# Assign pointers to adapter weights
# TODO(travis): don't update this if indices haven't changed
adapter_data = AdapterBatchData.from_meta(
adapter_meta,
self.layer_to_adapter_weights,
prefill,
batch.prefill_head_indices,
)
out, speculative_logits = self.forward(batch, adapter_data)
if prefill:
next_token_logits = (
......@@ -1128,8 +1236,13 @@ class FlashCausalLM(Model):
if prefill_logprobs
else speculative_logits
)
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
len(batch)
)
else:
next_token_logits = out
next_adapter_indices = batch.adapter_meta.adapter_indices
speculate = get_speculate()
(
......@@ -1195,6 +1308,12 @@ class FlashCausalLM(Model):
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
# Initialize adapter indices
# In decode, we only have one token per row in the batch, so grab last index
next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
end_index - 1
]
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if prefill_logprobs:
......@@ -1220,6 +1339,16 @@ class FlashCausalLM(Model):
batch.position_ids = next_position_ids + accepted_ids
batch.input_lengths_tensor += accepted_ids
batch.slot_indices += accepted_ids
batch.adapter_meta.adapter_indices = next_adapter_indices
if prefill:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
if prefill and prefill_logprobs:
# Get prefill logprobs
......
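Two adapter-related details in `generate_token` above are easy to miss: with speculative decoding each row produces `speculative_length + 1` tokens, so the per-row adapter index is repeated that many times and the segment boundaries are scaled by the same factor, and after a prefill step every row shrinks to a single token, so the indices are re-segmented with `find_segments`. A small sketch of the expansion arithmetic with plain tensors:

```python
import torch

# One adapter index per batch row (decode view): rows 0-1 use adapter 2, row 2 uses adapter 0.
adapter_indices = torch.tensor([2, 2, 0])
speculative_length = 3
new_length = speculative_length + 1  # accepted token plus speculative candidates

# Repeat each row's index once per generated token so the flat index tensor
# still lines up with the flattened (batch * new_length) token layout.
expanded = adapter_indices.unsqueeze(-1).expand(-1, new_length).reshape(-1)
print(expanded.tolist())  # [2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0]

# Segment boundaries scale by the same factor: [0, 2, 3] -> [0, 8, 12]
segments = torch.tensor([0, 2, 3], dtype=torch.int32) * new_length
print(segments.tolist())  # [0, 8, 12]
```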
......@@ -62,6 +62,7 @@ class FlashCohere(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashCohere, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
......
......@@ -87,6 +87,7 @@ class FlashDbrx(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashDbrx, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
......
......@@ -62,6 +62,7 @@ class FlashGemma(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashGemma, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
......
......@@ -69,6 +69,7 @@ class FlashGPT2(FlashCausalLM):
model = FlashGPT2ForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashGPT2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
......
import os
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from typing import Optional
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
......@@ -13,12 +14,24 @@ from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashLlama(FlashCausalLM):
def __init__(
......@@ -29,6 +42,7 @@ class FlashLlama(FlashCausalLM):
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
......@@ -78,6 +92,7 @@ class FlashLlama(FlashCausalLM):
model = FlashLlamaForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
......@@ -88,3 +103,69 @@ class FlashLlama(FlashCausalLM):
rank=rank,
world_size=world_size,
)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
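The `adapter_target_to_layer` map above is what lets several PEFT target names share one fused TGI module: `q_proj`, `k_proj` and `v_proj` all resolve to `layer.self_attn.query_key_value`, and `gate_proj`/`up_proj` both resolve to `gate_up_proj`. A hedged sketch of how such a map might be consumed when grouping an adapter's targets by fused module; the helper and the dictionary below are illustrative, not TGI internals:

```python
from typing import Dict, List, Set, Tuple

# Illustrative stand-in for the (weight_name, module) pairs built above; the
# strings on the right represent the fused modules that q/k/v and gate/up share.
layer_weights: Dict[Tuple[int, str], Tuple[str, str]] = {
    (0, "q_proj"): ("model.layers.0.self_attn.q_proj", "layer0.self_attn.query_key_value"),
    (0, "v_proj"): ("model.layers.0.self_attn.v_proj", "layer0.self_attn.query_key_value"),
    (0, "down_proj"): ("model.layers.0.mlp.down_proj", "layer0.mlp.down_proj"),
}

def targets_by_module(target_modules: List[str]) -> Dict[str, Set[str]]:
    """Group the PEFT target names an adapter was trained on (its
    adapter_config.json target_modules) by the fused module they land on."""
    grouped: Dict[str, Set[str]] = {}
    for (_layer_id, target), (weight_name, module) in layer_weights.items():
        if target in target_modules:
            grouped.setdefault(module, set()).add(weight_name)
    return grouped

# An adapter trained on q_proj and v_proj lands entirely on the fused QKV module:
print(targets_by_module(["q_proj", "v_proj"]))
```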
......@@ -3,7 +3,7 @@ import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import set_sliding_window
......@@ -21,6 +21,18 @@ from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class BaseFlashMistral(FlashCausalLM):
def __init__(
self,
......@@ -83,6 +95,7 @@ class BaseFlashMistral(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
num_layers, num_kv_heads, head_size = self.get_layer_config(model)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=num_layers,
......@@ -102,6 +115,75 @@ class BaseFlashMistral(FlashCausalLM):
model.model.head_size,
)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
# TODO: this is a hack to avoid the gate_proj for
# FlashStarcoder2, which doesn't have these layers
if hasattr(layer.mlp, "gate_up_proj"):
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
class FlashMistral(BaseFlashMistral):
def __init__(
......
......@@ -69,6 +69,7 @@ class FlashNeoXSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashNeoXSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.gpt_neox.layers),
......
......@@ -90,6 +90,7 @@ class FlashPhi(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashPhi, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
......
......@@ -71,6 +71,7 @@ class FlashQwen2(BaseFlashMistral):
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
......
......@@ -78,6 +78,7 @@ class FlashRWSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashRWSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.transformer.h),
......
......@@ -80,6 +80,7 @@ class FlashSantacoderSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashSantacoderSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.transformer.h),
......
......@@ -70,6 +70,7 @@ class FlashStarcoder2(BaseFlashMistral):
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
......
......@@ -212,6 +212,7 @@ class GalacticaSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
......
import torch
import os
from loguru import logger
from typing import Dict
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
......@@ -32,3 +33,14 @@ MODEL_ID = None
def set_model_id(model_id: str):
global MODEL_ID
MODEL_ID = model_id
# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
global ADAPTER_TO_INDEX
ADAPTER_TO_INDEX: Dict[str, int] = None
def set_adapter_to_index(adapter_to_index: Dict[str, int]):
global ADAPTER_TO_INDEX
ADAPTER_TO_INDEX = adapter_to_index
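For context on the new global above: the mapping from adapter id to index is expected to be built once at startup and installed via `set_adapter_to_index`, and `FlashCausalLMBatch.from_pb` then resolves each request with `ADAPTER_TO_INDEX.get(r.adapter_id, 0)`, falling back to 0 for plain base-model generation. A minimal sketch of that pattern; reserving index 0 for the base model is an assumption here, the diff itself only shows the `.get(..., 0)` fallback:

```python
from typing import Dict, List

def build_adapter_to_index(lora_adapter_ids: List[str]) -> Dict[str, int]:
    # Assumption: index 0 stays free for the base model, so loaded adapters start at 1.
    return {adapter_id: i + 1 for i, adapter_id in enumerate(lora_adapter_ids)}

# Hypothetical adapter ids, e.g. as supplied to the server at startup.
ADAPTER_TO_INDEX = build_adapter_to_index(["org/math-lora", "org/code-lora"])

def resolve(adapter_id: str) -> int:
    # Unknown or empty adapter ids fall back to 0, i.e. plain base-model generation,
    # mirroring the ADAPTER_TO_INDEX.get(r.adapter_id, 0) lookup in from_pb.
    return ADAPTER_TO_INDEX.get(adapter_id, 0)

print(resolve("org/code-lora"))  # 2
print(resolve(""))               # 0 -> base model
```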
......@@ -65,6 +65,7 @@ class GPTNeoxSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
......
......@@ -83,6 +83,7 @@ class IDEFICSSharded(IdeficsCausalLM):
torch.distributed.barrier(group=self.process_group)
super(IdeficsCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
......