Unverified Commit 04e1af94 authored by drbh, committed by GitHub

Enable multiple LoRA adapters (#2010)



* feat: first draft load multiple lora

* feat: load weights within layer and refactor lora pass

* fix: refactor and reduce lora math

* feat: baseline impl single request multi lora support

* feat: prefer lorax implementation and port loading logic

* fix: prefer adapter_data and refactors

* feat: prefer lorax's custom punica kernels and add mlp loras

* fix: adjust batch for bgmv

* fix: adjust adapter_segments logic when in batch

* fix: refactor and move changes to v3 proto

* fix: pass model_id for all flash causal lms

* fix: pass model_id for all causal and seq2seq lms

* fix: add model_id to model test

* feat: add lora support to mistral and refactors

* feat: prefer model id in request

* fix: include rust code for adapter id

* feat: bump launcher and add new lora docs

* feat: support base model generation and refactors

* fix: rename doc to retry ci build

* feat: support for vlm models

* fix: add adapter_data param and avoid missing layers

* fix: add adapter_data param to phi and neox

* fix: update all models forwards to include adapter_data

* fix: add model_id to IdeficsCausalLM

* Update lora.md

Fixed a typo

* Update lora.md

Fixing spam image

* fix: add lora kernel to dockerfile, support running without kernels and refactors

* fix: avoid dockerfile conflict

* fix: refactors and adjust flash llama lora logic

* fix: skip llama test due to CI issue (temp)

* fix: skip llama test CI (temp) 2

* fix: revert skips and prefer updated ci token for tests

* fix: refactors and helpful comments

* fix: add noop in TensorParallelAdapterRowLinear too

* fix: refactor and move shard_lora_weights logic

* fix: exit early if no adapter_data

---------
Co-authored-by: Derek <datavistics@gmail.com>
parent a2a97b05
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/weights.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Type

import torch


@dataclass
class AdapterBatchMetadata:
    # [batch_size]
    adapter_indices: torch.Tensor

    # [num_adapters]
    adapter_set: Set[int]

    # [num_segments + 1]
    adapter_segments: torch.Tensor

    # [num_segments]
    # maps from segment index to adapter index, i.e.:
    # segment_indices[s] == adapter_indices[i]
    segment_indices: List[int]


class AdapterWeights(ABC):
    @abstractclassmethod
    def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
        pass

    @property
    def speculative_tokens(self) -> int:
        return 0


class BatchAdapterWeights(ABC):
    @abstractclassmethod
    def has_adapter(self, adapter_index: int) -> bool:
        pass

    @abstractclassmethod
    def key(cls) -> str:
        pass

    @abstractclassmethod
    def load(
        cls,
        adapter_weights: Dict[int, AdapterWeights],
        meta: "AdapterBatchMetadata",
        prefill: bool,
        prefill_head_indices: torch.Tensor,
    ) -> Optional["BatchAdapterWeights"]:
        pass


class LayerAdapterWeights:
    """Adapter weights that apply to a particular layer."""

    def __init__(self):
        self.adapter_weights: Dict[int, AdapterWeights] = {}

    def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
        self.adapter_weights[adapter_idx] = weights

    def remove_adapter(self, adapter_idx: int):
        if adapter_idx not in self.adapter_weights:
            return
        del self.adapter_weights[adapter_idx]

    @property
    def max_speculative_tokens(self) -> int:
        return max(
            adapter_weights.speculative_tokens
            for adapter_weights in self.adapter_weights.values()
        )

    def is_empty(self) -> bool:
        return len(self.adapter_weights) == 0

    def get_data(
        self,
        meta: AdapterBatchMetadata,
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Dict[str, BatchAdapterWeights]:
        # bucket adapters by batch class
        adapter_batch_types: Dict[
            Type[BatchAdapterWeights], Dict[int, AdapterWeights]
        ] = defaultdict(dict)
        for adapter_index, adapter_weights in self.adapter_weights.items():
            for batch_type in adapter_weights.get_batch_types():
                adapter_batch_types[batch_type][adapter_index] = adapter_weights

        batch_data = {}
        for batch_type, adapter_weights in adapter_batch_types.items():
            batched_weights = batch_type.load(
                adapter_weights, meta, prefill, prefill_head_indices
            )
            if batched_weights is not None:
                batch_data[batch_type.key()] = batched_weights
        return batch_data


@dataclass
class AdapterBatchData:
    meta: AdapterBatchMetadata

    # layer type -> adapter type -> batch weight data
    data: Dict[str, Dict[str, BatchAdapterWeights]]

    prefill: bool

    @staticmethod
    def from_meta(
        meta: AdapterBatchMetadata,
        weights: Dict[str, LayerAdapterWeights],
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> "AdapterBatchData":
        data = {}
        for k, v in weights.items():
            if v.is_empty():
                continue
            data[k] = v.get_data(
                meta, prefill, prefill_head_indices if k == "lm_head" else None
            )
        return AdapterBatchData(meta=meta, data=data, prefill=prefill)

    def ranks(self) -> Set[int]:
        # TODO(travis): refactor to be less coupled to lora implementation
        ranks = set()
        for layer_data in self.data.values():
            lora_data = layer_data.get("lora")
            if lora_data is None:
                continue

            for rank_data in lora_data.rank_data.values():
                ranks.add(rank_data.rank)

        return ranks

    def layer_names(self) -> Set[str]:
        return set(self.data.keys())

    def adapter_keys(self) -> Set[str]:
        adapter_keys = set()
        for layer_data in self.data.values():
            adapter_keys.update(layer_data.keys())
        return adapter_keys

    @property
    def max_rank(self) -> int:
        ranks = self.ranks()
        return max(ranks) if len(ranks) > 0 else 0
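
The adapter_segments / segment_indices bookkeeping above is easiest to see on a concrete batch. The sketch below is illustrative only (it is not part of this commit, and the helper name build_segments is made up): it shows how the per-token adapter_indices collapse into contiguous segments, the form consumed by the segment-based (SGMV) kernels used elsewhere in this PR.

    import torch

    def build_segments(adapter_indices: torch.Tensor):
        # adapter_indices: one adapter id per token in the flattened batch,
        # e.g. [0, 0, 1, 1, 1, 0] -> three contiguous runs: 0, 1, 0
        indices = adapter_indices.tolist()
        segment_starts = [0] if indices else []
        segment_ids = [indices[0]] if indices else []
        for i in range(1, len(indices)):
            if indices[i] != indices[i - 1]:
                segment_starts.append(i)
                segment_ids.append(indices[i])
        # adapter_segments holds num_segments + 1 boundaries, as documented above
        adapter_segments = torch.tensor(
            segment_starts + [len(indices)], dtype=torch.int32
        )
        return adapter_segments, segment_ids

    adapter_indices = torch.tensor([0, 0, 1, 1, 1, 0], dtype=torch.int32)
    segments, segment_indices = build_segments(adapter_indices)
    # segments        -> tensor([0, 2, 5, 6], dtype=torch.int32)
    # segment_indices -> [0, 1, 0]  (segment s maps back to an adapter index)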
@@ -79,6 +79,18 @@ def serve(
    if otlp_endpoint is not None:
        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)

+    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
+
+    # split on comma and strip whitespace
+    lora_adapter_ids = (
+        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
+    )
+
+    if len(lora_adapter_ids) > 0:
+        logger.warning(
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
+        )
+
    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value
@@ -93,6 +105,7 @@ def serve(
    )
    server.serve(
        model_id,
+        lora_adapter_ids,
        revision,
        sharded,
        quantize,
@@ -113,6 +126,7 @@ def download_weights(
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
+    merge_lora: bool = False,
):
    # Remove default handler
    logger.remove()
@@ -143,18 +157,28 @@ def download_weights(
    ) is not None

    if not is_local_model:
-        try:
-            adapter_config_filename = hf_hub_download(
-                model_id, revision=revision, filename="adapter_config.json"
-            )
-            utils.download_and_unload_peft(
-                model_id, revision, trust_remote_code=trust_remote_code
-            )
-            is_local_model = True
-            utils.weight_files(model_id, revision, extension)
-            return
-        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
-            pass
+        # TODO: maybe reverse the default value of merge_lora?
+        # currently by default we don't merge the weights with the base model
+        if merge_lora:
+            try:
+                adapter_config_filename = hf_hub_download(
+                    model_id, revision=revision, filename="adapter_config.json"
+                )
+                utils.download_and_unload_peft(
+                    model_id, revision, trust_remote_code=trust_remote_code
+                )
+                is_local_model = True
+                utils.weight_files(model_id, revision, extension)
+                return
+            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+                pass
+        else:
+            try:
+                utils.peft.download_peft(
+                    model_id, revision, trust_remote_code=trust_remote_code
+                )
+            except Exception:
+                pass

    try:
        import json
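
The serve() change above reads the adapter list from a single LORA_ADAPTERS environment variable, splitting on commas and stripping whitespace, so spaces around the ids are harmless. A minimal sketch of the same parsing outside the server (the adapter ids shown are placeholders, not ones shipped with this commit):

    import os

    # e.g. LORA_ADAPTERS="org/adapter-one, org/adapter-two"
    os.environ["LORA_ADAPTERS"] = "org/adapter-one, org/adapter-two"

    raw = os.getenv("LORA_ADAPTERS", None)
    lora_adapter_ids = [x.strip() for x in raw.split(",")] if raw else []
    # -> ["org/adapter-one", "org/adapter-two"]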
@@ -12,3 +12,9 @@ from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
+
+from text_generation_server.layers.lora import (
+    LoraLinear,
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
+)
import math
import os
from typing import TYPE_CHECKING, Optional, Tuple, List

import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F
from torch.distributed import ProcessGroup

from text_generation_server.utils.sgmv import (
    add_lora_a_bgmv,
    add_lora_b_bgmv,
    has_sgmv,
    lora_a_sgmv_cutlass,
    lora_b_sgmv_cutlass,
    orient_for_rank,
)

if TYPE_CHECKING:
    from text_generation_server.adapters import AdapterBatchData
    from text_generation_server.adapters.lora import BatchLoraWeights


class LoraLinear(nn.Module):
    def __init__(
        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
    ):
        super().__init__()
        self.base_layer = base_layer
        self.layer_id = layer_id
        self.process_group = process_group

    def forward_layer_type(
        self,
        result: torch.Tensor,
        input: torch.Tensor,
        adapter_data: "AdapterBatchData",
        layer_type: str,
        start_idx: int,
        end_idx: int,
    ) -> torch.Tensor:
        if adapter_data is None:
            return result
        data = adapter_data.data.get(layer_type)
        data: Optional["BatchLoraWeights"] = (
            data.get("lora") if data is not None else None
        )

        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
            # In tensor-parallel configurations, each GPU processes a specific segment of the output.
            # The 'result' tensor represents the full output, which can vary in size based on
            # the layer type (e.g., attention vs. feed-forward layers). We define the current
            # segment using start_idx and end_idx. If the segment size doesn't match this GPU's
            # slice of 'result', we create a zero tensor of the correct size for LoRA computation.
            # This approach ensures accurate LoRA application across various layer sizes and
            # configurations, adapting to different model architectures and parallelization strategies.
            #
            # Example scenarios where this is necessary:
            # 1. The adapter's size doesn't evenly divide across GPUs.
            # 2. We're processing the last segment which might be smaller.
            # 3. Different projection layers (q, k, v) have different sizes.
            if end_idx - start_idx != result.shape[1]:
                proj = torch.zeros_like(result[:, start_idx:end_idx])
            else:
                proj = result

            for r, rank_segments in data.rank_data.items():
                lora_a_ptr = rank_segments.lora_a_ptr
                lora_b_ptr = rank_segments.lora_b_ptr

                if lora_a_ptr is None or lora_b_ptr is None:
                    raise ValueError("LoRA data is missing")

                if data.use_sgmv:
                    # Use SGMV for prefill
                    v = lora_a_sgmv_cutlass(
                        input,
                        rank_segments.tmp_shrink,
                        lora_a_ptr,
                        rank_segments.segment_starts,
                        rank_segments.segment_ends,
                        self.layer_id,
                        r,
                    )

                    if self.process_group.size() > 1:
                        v = self.collect_lora_a(v)

                    lora_b_sgmv_cutlass(
                        proj,
                        v,
                        rank_segments.tmp_expand,
                        lora_b_ptr,
                        rank_segments.segment_starts,
                        rank_segments.segment_ends,
                        self.layer_id,
                    )
                else:
                    # Use BGMV for decode
                    v = torch.zeros(
                        (input.size(0), r), dtype=input.dtype, device=input.device
                    )
                    # TODO: error with [-1, 0], but not [0, -1]
                    add_lora_a_bgmv(
                        v,
                        input,
                        lora_a_ptr,
                        rank_segments.indices,
                        self.layer_id,
                    )

                    if self.process_group.size() > 1:
                        v = self.collect_lora_a(v)

                    add_lora_b_bgmv(
                        proj,
                        v,
                        lora_b_ptr,
                        rank_segments.indices,
                        self.layer_id,
                    )

            if end_idx - start_idx != result.shape[1]:
                result[:, start_idx:end_idx] += proj
        else:
            for adapter_index in adapter_data.meta.adapter_set:
                if data is not None and data.has_adapter(adapter_index):
                    adapter_mask = (
                        (adapter_data.meta.adapter_indices == adapter_index)
                        .to(input.dtype)
                        .view(-1, 1)
                    )
                    layer_result = self.forward_lora(
                        input, data, adapter_index, adapter_mask
                    )
                    result[:, start_idx:end_idx] += layer_result

        return result

    def forward_lora(
        self,
        input: torch.Tensor,
        data: "BatchLoraWeights",
        adapter_index: int,
        adapter_mask: torch.Tensor,
    ) -> torch.Tensor:
        lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]

        lora_a = orient_for_rank(lora_a, lora_b.size(0))

        a_out = input @ lora_a
        if self.process_group.size() > 1:
            a_out = self.collect_lora_a(a_out)

        result = (a_out @ lora_b) * adapter_mask
        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("Implemented in subclasses")


class TensorParallelMultiAdapterLinear(LoraLinear):
    def __init__(
        self,
        base_layer: nn.Module,
        layer_id: int,
        layer_names: List[str],
        sizes: List[int],
        process_group: ProcessGroup,
    ):
        super().__init__(base_layer, layer_id, process_group)
        self.layer_names = layer_names
        self.sizes = sizes

    @classmethod
    def load(
        cls,
        base_layer: nn.Module,
        layer_id: int,
        layer_names: List[str],
        sizes: List[int],
        process_group: ProcessGroup,
    ):
        return TensorParallelMultiAdapterLinear(
            base_layer, layer_id, layer_names, sizes, process_group
        )

    def forward(
        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
    ) -> torch.Tensor:
        result = self.base_layer(input)

        # noop if no layer names are provided (e.g. for models without adapters)
        if self.layer_names is None:
            return result

        # handle models like Bloom that have inputs of shape
        # (batch_size, sequence_length, hidden_size)
        # we need to reshape them to (batch_size * sequence_length, hidden_size)
        # for the LoRA computation, then reshape back
        prev_shape = result.shape
        is_3d = len(input.shape) >= 3
        if is_3d:
            input = input.reshape(-1, input.shape[-1])
            result = result.reshape(-1, result.shape[-1])

        offset = 0
        for i, layer_name in enumerate(self.layer_names):
            start_idx = offset // self.process_group.size()
            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
            # different projection sizes across layers and model architectures.
            if self.sizes is not None:
                offset += self.sizes[i]
                end_idx = offset // self.process_group.size()
            else:
                end_idx = result.shape[1]

            result = self.forward_layer_type(
                result, input, adapter_data, layer_name, start_idx, end_idx
            )

        if is_3d:
            result = result.reshape(prev_shape)

        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
        # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
        #
        # TODO(travis): this is not very efficient as we do an all-gather for every adapter,
        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
        gathered_tensors = [
            torch.empty_like(a_out) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(gathered_tensors, a_out)
        return torch.cat(gathered_tensors, dim=1)


class TensorParallelAdapterRowLinear(LoraLinear):
    def __init__(self, base_layer, layer_id, layer_name, process_group):
        super().__init__(base_layer, layer_id, process_group)
        self.layer_name = layer_name

    @classmethod
    def load(cls, base_layer, layer_id, layer_name, process_group):
        return cls(base_layer, layer_id, layer_name, process_group)

    def forward(
        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
    ) -> torch.Tensor:
        result = self.base_layer(input)

        if self.layer_name is None:
            return result

        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
        stride = result.shape[-1] // self.process_group.size()
        start_idx = self.process_group.rank() * stride
        end_idx = (self.process_group.rank() + 1) * stride

        self.forward_layer_type(
            result, input, adapter_data, self.layer_name, start_idx, end_idx
        )

        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
        # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
        #
        # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
        torch.distributed.all_reduce(a_out, group=self.process_group)
        return a_out
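
When the punica SGMV/BGMV kernels are unavailable, forward_layer_type falls back to the plain dense math in forward_lora: shrink with lora_a, optionally gather/reduce across ranks, expand with lora_b, and mask out rows that belong to other adapters. A self-contained single-GPU sketch of that fallback (shapes and tensor names are illustrative, not taken from this PR; any LoRA scaling is assumed to be folded into the weights):

    import torch

    batch, hidden, out_dim, rank = 6, 16, 16, 8
    x = torch.randn(batch, hidden)
    base_out = torch.randn(batch, out_dim)            # stands in for base_layer(x)
    lora_a = torch.randn(hidden, rank)                # already oriented as (hidden, r)
    lora_b = torch.randn(rank, out_dim)

    # one adapter id per row of the batch, as in AdapterBatchMetadata.adapter_indices
    adapter_indices = torch.tensor([0, 0, 1, 1, 1, 0])
    adapter_index = 1
    adapter_mask = (adapter_indices == adapter_index).to(x.dtype).view(-1, 1)

    a_out = x @ lora_a                                # shrink to rank r
    layer_result = (a_out @ lora_b) * adapter_mask    # expand, zero rows of other adapters
    result = base_out + layer_result                  # result[:, start:end] += ... in the real code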
@@ -6,7 +6,7 @@
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional
+from typing import Optional, List
from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
@@ -253,6 +253,7 @@ for data in ModelType:
def get_model(
    model_id: str,
+    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
@@ -595,6 +596,7 @@ def get_model(
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@@ -90,6 +90,7 @@ class BLOOMSharded(CausalLM):
        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
+            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
@@ -538,6 +538,7 @@ class CausalLM(Model):
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        super(CausalLM, self).__init__(
+            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
@@ -514,6 +514,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
@@ -724,6 +724,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
@@ -460,6 +460,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.model(
@@ -445,6 +445,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        token_embeds = self.embed_tokens(input_ids)
        position_embeds = self.embed_positions(position_ids)
@@ -38,6 +38,8 @@ from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
@@ -51,43 +53,61 @@ if SYSTEM == "rocm":
        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")


-def load_attention(config, prefix, weights):
+def load_attention(config, prefix, weights, layer_id):
    # Only defined in granite.
    bias = getattr(config, "attention_bias", False)
+    head_size = config.hidden_size // config.num_attention_heads
+    sizes = None
+    prefixes = None

-    # if specific model type, load the correct attention
    if config.model_type == "phi3":
-        return TensorParallelColumnLinear.load_qkv(
+        prefix = f"{prefix}.qkv_proj"
+        base_layer = TensorParallelColumnLinear.load_qkv(
            config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=prefix,
            weights=weights,
            bias=bias,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
        )
    elif config.model_type == "baichuan":
-        return TensorParallelColumnLinear.load_qkv(
+        prefix = f"{prefix}.W_pack"
+        base_layer = TensorParallelColumnLinear.load_qkv(
            config,
-            prefix=f"{prefix}.W_pack",
+            prefix=prefix,
            weights=weights,
            bias=bias,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
        )
+    else:
+        prefixes = ["q_proj", "k_proj", "v_proj"]
+        sizes = [
+            head_size * config.num_attention_heads,
+            head_size * config.num_key_value_heads,
+            head_size * config.num_key_value_heads,
+        ]
+        base_layer = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=bias,
+        )

-    # otherwise, load the default attention based on the number of heads
-    return TensorParallelColumnLinear.load_multi(
-        config,
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        dim=0,
-        weights=weights,
-        bias=bias,
-    )
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer=base_layer,
+        layer_id=layer_id,
+        layer_names=prefixes,
+        sizes=sizes,
+        process_group=weights.process_group,
+    )


class FlashLlamaAttention(torch.nn.Module):
    def __init__(
        self,
+        index: int,
        prefix: str,
        config,
        weights,
@@ -121,14 +141,23 @@ class FlashLlamaAttention(torch.nn.Module):
            config.num_key_value_heads // weights.process_group.size()
        )

-        self.query_key_value = load_attention(config, prefix, weights)
+        self.query_key_value = load_attention(config, prefix, weights, index)
+        self.index = index

-        self.o_proj = TensorParallelRowLinear.load(
+        o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
+
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            index,
+            "o_proj",
+            process_group=weights.process_group,
+        )
+
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -145,8 +174,9 @@ class FlashLlamaAttention(torch.nn.Module):
        slots,
        input_lengths,
        max_s,
+        adapter_data,
    ):
-        qkv = self.query_key_value(hidden_states)
+        qkv = self.query_key_value(hidden_states, adapter_data)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
@@ -190,11 +220,13 @@ class FlashLlamaAttention(torch.nn.Module):
            max_s,
        )

-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )


class LlamaMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, index):
        super().__init__()
        self.hidden_act = config.hidden_act
        self.act = (
@@ -209,29 +241,54 @@ class LlamaMLP(nn.Module):
                ),
            )
        )
+        prefixes = None
+        sizes = None
+
        # Fuse gate and up proj
        bias = getattr(config, "mlp_bias", False)
        if config.model_type == "phi3":
-            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+            gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                config,
                prefix=f"{prefix}.gate_up_proj",
                weights=weights,
                bias=bias,
            )
        else:
-            self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+            prefixes = [f"gate_proj", f"up_proj"]
+            sizes = [
+                config.intermediate_size,
+                config.intermediate_size,
+            ]
+            gate_up_proj = TensorParallelColumnLinear.load_multi(
                config,
                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                weights=weights,
                dim=0,
                bias=bias,
            )
-        self.down_proj = TensorParallelRowLinear.load(
+
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            index,
+            layer_names=prefixes,
+            sizes=sizes,
+            process_group=weights.process_group,
+        )
+
+        down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=bias,
        )
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            index,
+            "down_proj",
+            process_group=weights.process_group,
+        )
+
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )
@@ -239,7 +296,7 @@ class LlamaMLP(nn.Module):
        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, adapter_data):
        if (
            SYSTEM == "rocm"
            and self.hidden_act == "silu"
@@ -253,20 +310,27 @@ class LlamaMLP(nn.Module):
                device="cuda",
            )
            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
-            return self.down_proj(out)
+            return self.down_proj(out, adapter_data)
        else:
-            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+            )


class FlashLlamaLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, index, prefix, config, weights):
        super().__init__()
        self.self_attn = FlashLlamaAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            index=index,
+            prefix=f"{prefix}.self_attn",
+            config=config,
+            weights=weights,
+        )
+        self.mlp = LlamaMLP(
+            prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
        )
-        self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -289,6 +353,7 @@ class FlashLlamaLayer(nn.Module):
        slots,
        input_lengths,
        max_s,
+        adapter_data,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -303,6 +368,7 @@ class FlashLlamaLayer(nn.Module):
            slots,
            input_lengths,
            max_s,
+            adapter_data,
        )

        # faster post attention rms norm
@@ -310,7 +376,7 @@ class FlashLlamaLayer(nn.Module):
            attn_output, res
        )

-        mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = self.mlp(normed_attn_res_output, adapter_data)

        return mlp_output, attn_res

@@ -325,6 +391,7 @@ class FlashLlamaModel(torch.nn.Module):
        self.layers = nn.ModuleList(
            [
                FlashLlamaLayer(
+                    index=layer_id,
                    prefix=(
                        f"model.layers.{layer_id}"
                        if not prefix
@@ -360,6 +427,7 @@ class FlashLlamaModel(torch.nn.Module):
        max_s: int,
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
+        adapter_data,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds

@@ -382,6 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
                slots,
                input_lengths,
                max_s,
+                adapter_data,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
@@ -423,6 +492,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = self.model(
@@ -436,6 +506,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
            max_s,
            true_max_s=max_s,
            prefill_cache_indices=prefill_cache_indices,
+            adapter_data=adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
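
The sizes passed to TensorParallelMultiAdapterLinear above are whole-model output widths; TensorParallelMultiAdapterLinear.forward then derives each projection's per-rank slice of the fused q/k/v output by dividing a running offset by the world size. A worked example of that index arithmetic (the numbers are chosen for illustration only: 32 query heads, 8 key/value heads, head size 128, 2-way tensor parallelism; they are not taken from this diff):

    # assumptions: hidden_size=4096, 32 query heads, 8 key/value heads, world size 2
    head_size = 4096 // 32
    sizes = [head_size * 32, head_size * 8, head_size * 8]  # q_proj, k_proj, v_proj
    world_size = 2

    offset = 0
    for name, size in zip(["q_proj", "k_proj", "v_proj"], sizes):
        start_idx = offset // world_size
        offset += size
        end_idx = offset // world_size
        # per-rank slice of the fused qkv output that this layer_name's LoRA touches
        print(name, start_idx, end_idx)
    # q_proj 0 2048
    # k_proj 2048 2560
    # v_proj 2560 3072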
@@ -38,6 +38,8 @@ from text_generation_server.layers import (
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
@@ -107,12 +109,7 @@ class MistralConfig(PretrainedConfig):


class MistralAttention(torch.nn.Module):
-    def __init__(
-        self,
-        prefix: str,
-        config,
-        weights,
-    ):
+    def __init__(self, prefix: str, config, weights, layer_id):
        super().__init__()
        self.max_past = (
            config.sliding_window if config.sliding_window is not None else -1
@@ -140,7 +137,7 @@ class MistralAttention(torch.nn.Module):
            config.num_key_value_heads // weights.process_group.size()
        )

-        self.query_key_value = TensorParallelColumnLinear.load_multi(
+        query_key_value = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
@@ -148,12 +145,31 @@ class MistralAttention(torch.nn.Module):
            bias=False,
        )

-        self.o_proj = TensorParallelRowLinear.load(
+        head_size = config.hidden_size // config.num_attention_heads
+        self.query_key_value = TensorParallelMultiAdapterLinear.load(
+            query_key_value,
+            layer_id,
+            ["q_proj", "k_proj", "v_proj"],
+            sizes=[
+                head_size * config.num_attention_heads,
+                head_size * config.num_key_value_heads,
+                head_size * config.num_key_value_heads,
+            ],
+            process_group=weights.process_group,
+        )
+
+        o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
+
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            layer_id,
+            "o_proj",
+            process_group=weights.process_group,
+        )
+
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -171,8 +187,9 @@ class MistralAttention(torch.nn.Module):
        input_lengths,
        max_s,
        prefill_cache_indices,
+        adapter_data,
    ):
-        qkv = self.query_key_value(hidden_states)
+        qkv = self.query_key_value(hidden_states, adapter_data)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
@@ -224,11 +241,13 @@ class MistralAttention(torch.nn.Module):
            max_s,
        )

-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )


class MistralMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, layer_id):
        super().__init__()
        self.hidden_act = config.hidden_act
        self.act = (
@@ -244,19 +263,37 @@ class MistralMLP(nn.Module):
            )
        )
        # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+        gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
-        self.down_proj = TensorParallelRowLinear.load(
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            layer_id,
+            ["gate_proj", "up_proj"],
+            sizes=[
+                config.intermediate_size,
+                config.intermediate_size,
+            ],
+            process_group=weights.process_group,
+        )
+
+        down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
+
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            layer_id,
+            "down_proj",
+            process_group=weights.process_group,
+        )
+
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )
@@ -264,7 +301,7 @@ class MistralMLP(nn.Module):
        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, adapter_data):
        if (
            SYSTEM == "rocm"
            and self.hidden_act == "silu"
@@ -278,20 +315,27 @@ class MistralMLP(nn.Module):
                device="cuda",
            )
            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
-            return self.down_proj(out)
+            return self.down_proj(out, adapter_data)
        else:
-            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+            )


class MistralLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, layer_id):
        super().__init__()
        self.self_attn = MistralAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn",
+            config=config,
+            weights=weights,
+            layer_id=layer_id,
+        )
+        self.mlp = MistralMLP(
+            prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
        )
-        self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -315,6 +359,7 @@ class MistralLayer(nn.Module):
        input_lengths,
        max_s,
        prefill_cache_indices,
+        adapter_data,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -330,6 +375,7 @@ class MistralLayer(nn.Module):
            input_lengths,
            max_s,
            prefill_cache_indices,
+            adapter_data,
        )

        # faster post attention rms norm
@@ -337,7 +383,7 @@ class MistralLayer(nn.Module):
            attn_output, res
        )

-        mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = self.mlp(normed_attn_res_output, adapter_data)

        return mlp_output, attn_res

@@ -355,6 +401,7 @@ class MistralModel(torch.nn.Module):
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
+                    layer_id=layer_id,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
@@ -381,6 +428,7 @@ class MistralModel(torch.nn.Module):
        max_s: int,
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
+        adapter_data: Optional[torch.Tensor] = None,
    ):
        hidden_states = inputs_embeds
        # Get rotary cos and sin for this forward
@@ -403,6 +451,7 @@ class MistralModel(torch.nn.Module):
                input_lengths,
                max_s,
                prefill_cache_indices,
+                adapter_data,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
@@ -454,6 +503,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        true_max_s = max_s
        if prefill_cache_indices is not None:
@@ -476,6 +526,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
            max_s,
            true_max_s,
            prefill_cache_indices,
+            adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
@@ -638,6 +638,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        true_max_s = max_s
        if prefill_cache_indices is not None:
@@ -390,6 +390,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.gpt_neox(
            input_ids,
@@ -74,6 +74,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
        # Unused here
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        # TODO This is odd but apparently pali gemma position ids start at 1.
@@ -400,6 +400,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
@@ -359,6 +359,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        true_max_s = max_s
        if prefill_cache_indices is not None:
@@ -672,6 +672,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(
            input_ids,
@@ -483,6 +483,7 @@ class FlashSantacoderForCausalLM(nn.Module):
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(
            input_ids,