Unverified Commit abd58ff8 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

feat(server): Rework model loading (#344)

# What does this PR do?

Reworked the loading logic. Idea is to use cleaner loading code:

- Remove need for `no_init_weights`
- Remove all weird `bnb_linear` and `load_weights` and
`post_load_weights`.

New code layout:

- New class `Weights` in charge of handling loading the weights from
multiple files into appropiate tensors (potentially sharded)
- TP layers now are "shells", they contain the code to know what kind of
sharding we need + eventual `all_reduce`. They do not inherit from
linear, but they contain some kind of Linear instead
- the contained linear can be either FastLinear, BnbLinear or GPTq
Linear next.
- All modeling code is explictly made for sharding, process group is
just no-ops for non sharded code (removes a lot of test cases)

![Screenshot from 2023-05-19
23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f

)

---------
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-41-161.taildb5d.ts.net>
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: default avatarOlivierDehaene <olivier@huggingface.co>
Co-authored-by: default avatarOlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
parent 19c41824
......@@ -6,12 +6,17 @@ from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
@pytest.fixture(scope="session")
def default_bloom():
return BLOOM("bigscience/bloom-560m")
model_id = "bigscience/bloom-560m"
revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision)
return BLOOMSharded(model_id)
@pytest.fixture(scope="session")
......
import os
import torch
from loguru import logger
......@@ -8,17 +9,20 @@ from typing import Optional
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPT, OPTSharded
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
try:
if torch.cuda.is_available():
if (
torch.cuda.is_available()
and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
):
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
......@@ -30,14 +34,12 @@ try:
f"GPU with CUDA capability {major} {minor} is not supported"
)
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded
from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import (
FlashLlama,
FlashLlamaSharded,
)
from text_generation_server.models.flash_santacoder import (
FlashSantacoder,
FlashSantacoderSharded,
)
......@@ -52,30 +54,22 @@ except ImportError:
__all__ = [
"Model",
"BLOOM",
"BLOOMSharded",
"CausalLM",
"FlashCausalLM",
"Galactica",
"GalacticaSharded",
"GPTNeoxSharded",
"Seq2SeqLM",
"SantaCoder",
"OPT",
"OPTSharded",
"T5Sharded",
"get_model",
]
if FLASH_ATTENTION:
__all__.append(FlashNeoX)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRW)
__all__.append(FlashRWSharded)
__all__.append(FlashSantacoder)
__all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama)
__all__.append(FlashLlamaSharded)
FLASH_ATT_ERROR_MESSAGE = (
"{} requires Flash Attention CUDA kernels to be installed.\n"
......@@ -102,36 +96,24 @@ def get_model(
trust_remote_code: bool,
) -> Model:
if "facebook/galactica" in model_id:
if sharded:
return GalacticaSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return Galactica(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
return GalacticaSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
if model_id.startswith("bigcode/"):
if sharded:
if not FLASH_ATTENTION:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
if FLASH_ATTENTION:
return FlashSantacoderSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(
return SantaCoder(
model_id,
revision,
quantize=quantize,
......@@ -144,20 +126,19 @@ def get_model(
model_type = config_dict["model_type"]
if model_type == "gpt_bigcode":
if sharded:
if not FLASH_ATTENTION:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
if FLASH_ATTENTION:
return FlashSantacoderSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(
return SantaCoder(
model_id,
revision,
quantize=quantize,
......@@ -165,33 +146,45 @@ def get_model(
)
if model_type == "bloom":
if sharded:
return BLOOMSharded(
return BLOOMSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
elif model_type == "gpt_neox":
if FLASH_ATTENTION:
return FlashNeoXSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
elif sharded:
return GPTNeoxSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return BLOOM(
return CausalLM(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "gpt_neox":
if sharded:
neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
return neox_cls(
elif model_type == "llama":
if FLASH_ATTENTION:
return FlashLlama(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
else:
neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
return neox_cls(
return CausalLM(
model_id,
revision,
quantize=quantize,
......@@ -217,7 +210,7 @@ def get_model(
)
else:
if FLASH_ATTENTION and not config_dict.get("alibi", False):
return FlashRW(
return FlashRWSharded(
model_id,
revision,
quantize=quantize,
......@@ -231,42 +224,12 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "llama":
if sharded:
if FLASH_ATTENTION:
return FlashLlamaSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
else:
llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
return llama_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "opt":
if sharded:
return OPTSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return OPT(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
elif model_type == "opt":
return OPTSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
if model_type == "t5":
elif model_type == "t5":
if sharded:
return T5Sharded(
model_id,
......
import torch
import torch.distributed
from typing import List, Optional, Type
from typing import Optional, Type
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
PreTrainedTokenizerBase,
)
from transformers.models.bloom.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
class BloomCausalLMBatch(CausalLMBatch):
@classmethod
......@@ -42,34 +31,12 @@ class BloomCausalLMBatch(CausalLMBatch):
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super(BloomCausalLMBatch, cls).from_pb(
pb=pb, tokenizer=tokenizer, dtype=dtype, device=device
)
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch
class BLOOM(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
super(BLOOM, self).__init__(
model_id=model_id,
revision=revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch
class BLOOMSharded(BLOOM):
class BLOOMSharded(CausalLM):
def __init__(
self,
model_id: str,
......@@ -101,25 +68,16 @@ class BLOOMSharded(BLOOM):
trust_remote_code=trust_remote_code,
)
config.pad_token_id = 3
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
model = BloomForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
......@@ -131,132 +89,9 @@ class BLOOMSharded(BLOOM):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
if name.startswith("transformer.") or name.startswith("lm_head."):
full_name = name
else:
full_name = f"transformer.{name}"
module_name, param_name = full_name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[full_name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif (
isinstance(module, TensorParallelEmbedding)
or name == "lm_head.weight"
):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
......@@ -269,9 +104,5 @@ class BLOOMSharded(BLOOM):
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
logits = outputs.logits
return logits, outputs.past_key_values
# coding=utf-8
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BLOOM model."""
import math
import os
import warnings
from typing import Optional, Tuple, Union
import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import LayerNorm
from torch.nn import functional as F
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
from transformers import BloomConfig, PreTrainedModel
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelHead,
)
CUSTOM_KERNELS_ENABLED = False
if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
try:
from custom_kernels import fused_bloom_attention_cuda
CUSTOM_KERNELS_ENABLED = True
except ImportError:
pass
_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
_CONFIG_FOR_DOC = "BloomConfig"
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"bigscience/bigscience-small-testing",
"bigscience/bloom-560m",
"bigscience/bloom-1b1",
"bigscience/bloom-1b7",
"bigscience/bloom-3b",
"bigscience/bloom-7b1",
"bigscience/bloom",
]
def _make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.ones(
(target_length, target_length + past_key_values_length),
dtype=torch.bool,
device=device,
)
mask = mask.triu(1 + past_key_values_length)
expanded_mask = mask.unsqueeze(0).expand(
batch_size, target_length, target_length + past_key_values_length
)
return expanded_mask
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
"""
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
"""
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, tgt_length, src_length)
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
Args:
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
attention_mask (`torch.Tensor`):
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
num_heads (`int`, *required*):
number of heads
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
"""
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32,
)
powers = torch.arange(
1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(
1,
1 + 2 * num_remaining_heads,
2,
device=attention_mask.device,
dtype=torch.int32,
)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
return alibi
# @torch.jit.script
def dropout_add(
x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool
) -> torch.Tensor:
"""
Dropout add function
Args:
x (`torch.tensor`, *required*):
input tensor
residual (`torch.tensor`, *required*):
esidual tensor
prob (`float`, *required*):
dropout probability
training (`bool`, *required*):
training mode
"""
out = F.dropout(x, p=prob, training=training)
out = residual + out
return out
# @torch.jit.script # this is shit for unknow reasons.
def _split_heads(
fused_qkv: torch.Tensor, num_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
storage as `fused_qkv`
Args:
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
Returns:
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)
query_layer = query_layer.transpose(1, 2).reshape(
batch_size * num_heads, seq_length, head_dim
)
key_layer = key_layer.permute(0, 2, 3, 1).reshape(
batch_size * num_heads, head_dim, seq_length
)
value_layer = value_layer.transpose(1, 2).reshape(
batch_size * num_heads, seq_length, head_dim
)
return query_layer, key_layer, value_layer
# @torch.jit.script
def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
"""
Merge heads together over the last dimenstion
Args:
x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
Returns:
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
"""
# What we want to achieve is:
# batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
batch_size_and_num_heads, seq_length, _ = x.shape
batch_size = batch_size_and_num_heads // num_heads
# First view to decompose the batch size
# batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
x = x.view(batch_size, num_heads, seq_length, head_dim)
# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
x = x.permute(0, 2, 1, 3)
# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
return x.reshape(batch_size, seq_length, num_heads * head_dim)
class BloomAttention(nn.Module):
def __init__(self, prefix, config: BloomConfig, weights):
super().__init__()
self.pretraining_tp = config.pretraining_tp
self.slow_but_exact = config.slow_but_exact
self.process_group = weights.process_group
self.hidden_size = config.hidden_size
self.num_heads = config.n_head
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
self.hidden_dropout = config.hidden_dropout
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.beta = 1.0
process_group = weights.process_group
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(
config=config,
prefix=f"{prefix}.query_key_value",
weights=weights,
bias=True,
)
self.dense = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
self.attention_dropout = nn.Dropout(config.attention_dropout)
@staticmethod
def compute_attention(
fused_qkv: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
alibi: torch.Tensor,
attention_mask: torch.Tensor,
head_mask: Optional[torch.Tensor],
beta: float,
inv_norm_factor: float,
num_heads: int,
use_cache: bool,
):
batch_size, q_length, three_times_hidden_size = fused_qkv.shape
head_dim = three_times_hidden_size // (3 * num_heads)
batch_size * num_heads
### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = _split_heads(
fused_qkv, num_heads=num_heads, head_dim=head_dim
)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, head_dim, kv_length]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
past_key = past_key.view(-1, *past_key.shape[-2:])
key_layer = torch.cat((past_key, key_layer), dim=2)
past_value = past_value.view(-1, *past_value.shape[-2:])
value_layer = torch.cat((past_value, value_layer), dim=1)
_, _, kv_length = key_layer.shape
if use_cache is True:
present = (key_layer, value_layer)
else:
present = None
###
# [batch_size * num_heads, q_length, kv_length]
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
attention_scores = alibi.baddbmm(
batch1=query_layer,
batch2=key_layer,
beta=beta,
alpha=inv_norm_factor,
)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16:
attention_scores = attention_scores.to(torch.float)
# torch.finfo not supported by torch.jit, we temporarily remplace with `-1e34`
attn_weights = attention_scores.masked_fill_(
attention_mask, torch.finfo(attention_scores.dtype).min
)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
input_dtype
)
# # [batch_size, num_heads, q_length, kv_length]
# attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)
# change view [batch_size, num_heads, q_length, head_dim]
context_layer = _merge_heads(
context_layer, num_heads=num_heads, head_dim=head_dim
)
return context_layer, present, attention_probs
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(
hidden_states
) # [batch_size, seq_length, 3 x hidden_size]
batch_size, q_length, _ = fused_qkv.shape
if layer_past is not None:
past_key, past_value = layer_past
layer_past = (
past_key.view(-1, *past_key.shape[-2:]),
past_value.view(-1, *past_value.shape[-2:]),
)
if CUSTOM_KERNELS_ENABLED:
assert self.training is False, "Only foward pass was implemented"
assert (
attention_mask.shape[-1] < 4096
), "Custom kernel support only up to 4096 tokens"
(
context_layer,
present,
attention_probs,
) = fused_bloom_attention_cuda.forward(
fused_qkv,
layer_past,
alibi,
attention_mask,
head_mask,
self.beta,
self.inv_norm_factor,
self.num_heads,
use_cache,
)
else:
context_layer, present, attention_probs = self.compute_attention(
fused_qkv=fused_qkv,
layer_past=layer_past,
alibi=alibi,
attention_mask=attention_mask,
head_mask=head_mask,
beta=self.beta,
inv_norm_factor=self.inv_norm_factor,
num_heads=self.num_heads,
use_cache=use_cache,
)
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
# output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
output_tensor += residual
outputs = (output_tensor, present)
if output_attentions:
outputs += (attention_probs,)
return outputs
class BloomMLP(nn.Module):
def __init__(self, prefix, config: BloomConfig, weights):
super().__init__()
self.pretraining_tp = config.pretraining_tp
self.slow_but_exact = config.slow_but_exact
self.dense_h_to_4h = TensorParallelColumnLinear.load(
config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
)
self.dense_4h_to_h = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
)
self.gelu_impl = torch.nn.GELU(approximate="tanh")
self.hidden_dropout = config.hidden_dropout
def forward(
self, hidden_states: torch.Tensor, residual: torch.Tensor
) -> torch.Tensor:
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
if self.pretraining_tp > 1 and self.slow_but_exact:
intermediate_output = torch.zeros_like(residual)
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
for i in range(self.pretraining_tp):
intermediate_output = intermediate_output + F.linear(
hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense_4h_to_h.weight[
:, int(i * slices) : int((i + 1) * slices)
],
)
else:
intermediate_output = self.dense_4h_to_h(hidden_states)
# output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
intermediate_output += residual
return intermediate_output
class BloomBlock(nn.Module):
def __init__(self, layer_id: int, config: BloomConfig, weights):
super().__init__()
prefix = f"h.{layer_id}"
self.input_layernorm = LayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.num_heads = config.n_head
self.self_attention = BloomAttention(
prefix=f"{prefix}.self_attention", config=config, weights=weights
)
self.post_attention_layernorm = LayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm
)
self.hidden_dropout = config.hidden_dropout
def forward(
self,
hidden_states: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
# hidden_states: [batch_size, seq_length, hidden_size]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self attention.
attn_outputs = self.self_attention(
layernorm_output,
residual,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attention_output = attn_outputs[0]
outputs = attn_outputs[1:]
layernorm_output = self.post_attention_layernorm(attention_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output, residual)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
class BloomPreTrainedModel(PreTrainedModel):
config_class = BloomConfig
base_model_prefix = "transformer"
_no_split_modules = ["BloomBlock"]
@staticmethod
def _convert_to_standard_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
num_heads, ...]))
"""
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
num_heads = batch_size_times_num_heads // batch_size
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
@staticmethod
def _convert_to_bloom_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
"""
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
batch_size_times_num_heads = batch_size * num_heads
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
class BloomModel(BloomPreTrainedModel):
def __init__(self, config: BloomConfig, weights):
super().__init__(config)
self.embed_dim = config.hidden_size
self.num_heads = config.n_head
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.word_embeddings = TensorParallelEmbedding(
prefix="word_embeddings", weights=weights
)
self.word_embeddings_layernorm = LayerNorm.load(
prefix="word_embeddings_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
# Transformer blocks
self.h = nn.ModuleList(
[
BloomBlock(layer_id=layer_id, config=config, weights=weights)
for layer_id in range(config.num_hidden_layers)
]
)
# Final Layer Norm
self.ln_f = LayerNorm.load(
prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon
)
def _prepare_attn_mask(
self,
attention_mask: torch.Tensor,
input_shape: Tuple[int, int],
past_key_values_length: int,
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
device=device,
past_key_values_length=past_key_values_length,
)
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
def set_input_embeddings(self, new_embeddings: torch.Tensor):
self.word_embeddings = new_embeddings
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[-1]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), device=hidden_states.device
)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = build_alibi_tensor(attention_mask, self.num_heads)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
if hasattr(self, "tp_rank"):
assert self.num_heads % self.tp_world_size == 0
block_size = self.num_heads // self.tp_world_size
alibi = alibi[
:, self.tp_rank * block_size : (self.tp_rank + 1) * block_size
]
alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
else:
alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)
causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)
alibi = alibi.to(hidden_states.dtype)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (
outputs[2 if use_cache else 1],
)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config, weights):
super().__init__(config)
self.transformer = BloomModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,
prefix="word_embeddings",
weights=weights,
)
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
# only last token for input_ids if past is not None
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
# the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
......@@ -30,21 +30,23 @@ import flash_attn_cuda
import dropout_layer_norm
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
TensorParallelHead,
)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
def __init__(self, prefix, weights, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, hidden_states, residual=None):
......@@ -91,35 +93,35 @@ class LlamaRMSNorm(nn.Module):
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
num_heads,
hidden_size,
process_group=None,
prefix: str,
config,
weights,
):
super().__init__()
self.num_heads = num_heads
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False)
self.o_proj = FastLinear(hidden_size, hidden_size, bias=False)
else:
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
3 * hidden_size,
bias=False,
process_group=process_group,
)
self.o_proj = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=False,
process_group=process_group,
)
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
def forward(
self,
......@@ -195,8 +197,9 @@ class FlashLlamaAttention(torch.nn.Module):
class LlamaMLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
......@@ -207,32 +210,23 @@ class LlamaMLP(nn.Module):
else "none",
)
)
if process_group is None:
# Fuse gate and up proj
self.gate_up_proj = FastLinear(
hidden_size, 2 * intermediate_size, bias=False
)
self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False)
self.intermediate_size = intermediate_size
else:
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear(
hidden_size,
2 * intermediate_size,
bias=False,
process_group=process_group,
)
self.down_proj = TensorParallelRowLinear(
intermediate_size,
hidden_size,
bias=False,
process_group=process_group,
reduce=True,
)
self.intermediate_size = self.down_proj.in_features
self.process_group = process_group
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
......@@ -241,22 +235,22 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
def __init__(
self,
num_heads,
act,
hidden_size,
intermediate_size,
rms_norm_eps,
process_group=None,
):
def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = FlashLlamaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group)
self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group)
self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
self.input_layernorm = LlamaRMSNorm(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = LlamaRMSNorm(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
......@@ -295,54 +289,35 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module):
def __init__(self, config, process_group=None):
super(FlashLlamaModel, self).__init__()
def __init__(self, config, weights):
super().__init__()
self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.embed_tokens = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
FlashLlamaLayer(
config.num_attention_heads,
config.hidden_act,
config.hidden_size,
config.intermediate_size,
config.rms_norm_eps,
process_group,
layer_id,
config,
weights,
)
for _ in range(config.num_hidden_layers)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = LlamaRMSNorm(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.embed_tokens, TensorParallelEmbedding):
self.embed_tokens.add_null_idx()
for layer in self.layers:
layer: FlashLlamaLayer
layer.self_attn.query_key_value.prepare_weights(quantize)
layer.self_attn.o_proj.prepare_weights(quantize)
layer.mlp.gate_up_proj.prepare_weights(quantize)
layer.mlp.down_proj.prepare_weights(quantize)
def forward(
self,
input_ids,
......@@ -410,29 +385,15 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__()
self.process_group = process_group
if self.process_group is not None:
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.model = FlashLlamaModel(config, process_group)
if self.model.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.model.post_load_weights(quantize)
self.lm_head.prepare_weights()
self.model = FlashLlamaModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,
prefix="lm_head",
weights=weights,
)
def forward(
self,
......@@ -457,12 +418,4 @@ class FlashLlamaForCausalLM(torch.nn.Module):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.model.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present
......@@ -31,61 +31,81 @@ from typing import Optional
import flash_attn_cuda
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelHead,
FastLayerNorm,
PositionRotaryEmbedding,
get_linear,
)
def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
if config.use_parallel_residual:
return linear
else:
return TensorParallelRowLinear(linear, process_group=weights.process_group)
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
weight = (
weight.view(
num_heads,
3,
head_size,
hidden_size,
)
.permute(1, 0, 2, 3)
.reshape(-1, hidden_size)
)
bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
linear = get_linear(weight, bias, config.quantize)
if config.use_parallel_residual:
return linear
else:
return TensorParallelColumnLinear(linear)
class FlashNeoxAttention(torch.nn.Module):
def __init__(
self,
num_heads,
hidden_size,
rotary_pct,
rotary_emb_base,
process_group=None,
reduce=True,
):
def __init__(self, config, prefix, weights):
super().__init__()
num_heads = config.num_attention_heads
hidden_size = config.hidden_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.num_heads = self.num_heads // weights.process_group.size()
rotary_ndims = int(self.head_size * rotary_pct)
self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
self.softmax_scale = self.head_size ** (-0.5)
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)
if process_group is None:
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size)
self.dense = FastLinear(hidden_size, hidden_size)
else:
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
3 * hidden_size,
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size, hidden_size, process_group=process_group, reduce=reduce
)
self.softmax_scale = self.head_size ** (-0.5)
def shuffle_qkv_dims(self):
"""Swap dims to avoid an additional permute"""
self.query_key_value.weight = torch.nn.Parameter(
self.query_key_value.weight.view(
self.num_heads, 3, self.head_size, self.hidden_size
)
.permute(1, 0, 2, 3)
.reshape(-1, self.hidden_size)
self.query_key_value = load_qkv(
config,
prefix=f"{prefix}.query_key_value",
weights=weights,
num_heads=self.num_heads,
head_size=self.head_size,
hidden_size=self.hidden_size,
)
self.query_key_value.bias = torch.nn.Parameter(
self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
.permute(1, 0, 2)
.reshape(-1)
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
def forward(
......@@ -162,10 +182,9 @@ class FlashNeoxAttention(torch.nn.Module):
class FlashMLP(nn.Module):
def __init__(
self, act, hidden_size, intermediate_size, process_group=None, reduce=True
):
def __init__(self, config, prefix, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
......@@ -177,22 +196,12 @@ class FlashMLP(nn.Module):
)
)
if process_group is None:
self.dense_h_to_4h = FastLinear(hidden_size, intermediate_size)
self.dense_4h_to_h = FastLinear(intermediate_size, hidden_size)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
intermediate_size,
process_group=process_group,
)
self.dense_4h_to_h = TensorParallelRowLinear(
intermediate_size,
hidden_size,
process_group=process_group,
reduce=reduce,
)
self.process_group = process_group
self.dense_h_to_4h = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
)
self.dense_4h_to_h = load_row(
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
)
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
......@@ -202,38 +211,28 @@ class FlashMLP(nn.Module):
class FlashNeoXLayer(nn.Module):
def __init__(
self,
num_heads,
act,
hidden_size,
intermediate_size,
rotary_pct,
rotary_emb_base,
layer_norm_eps,
use_parallel_residual,
process_group=None,
):
def __init__(self, layer_id, config, weights):
super().__init__()
self.use_parallel_residual = use_parallel_residual
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.attention = FlashNeoxAttention(
num_heads,
hidden_size,
rotary_pct,
rotary_emb_base,
process_group,
reduce=not use_parallel_residual,
layer_norm_eps = config.layer_norm_eps
prefix = f"gpt_neox.layers.{layer_id}"
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=layer_norm_eps
)
self.mlp = FlashMLP(
act,
hidden_size,
intermediate_size,
process_group,
reduce=not use_parallel_residual,
self.post_attention_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=layer_norm_eps,
)
self.attention = FlashNeoxAttention(
config, prefix=f"{prefix}.attention", weights=weights
)
self.process_group = process_group
self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
self.process_group = weights.process_group
def forward(
self,
......@@ -266,9 +265,7 @@ class FlashNeoXLayer(nn.Module):
mlp_output = self.mlp(ln2_hidden_states)
intermediate = mlp_output + attn_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group)
torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate + hidden_states, None
else:
......@@ -302,42 +299,24 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__(config)
self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.embed_in = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
else:
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights
)
self.layers = nn.ModuleList(
[
FlashNeoXLayer(
config.num_attention_heads,
config.hidden_act,
config.hidden_size,
config.intermediate_size,
config.rotary_pct,
config.rotary_emb_base,
config.layer_norm_eps,
config.use_parallel_residual,
process_group,
)
for _ in range(config.num_hidden_layers)
FlashNeoXLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.final_layer_norm = FastLayerNorm(
config.hidden_size, eps=config.layer_norm_eps
self.final_layer_norm = FastLayerNorm.load(
prefix="gpt_neox.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
self.gradient_checkpointing = False
......@@ -345,29 +324,6 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self.head_size = self.layers[0].attention.head_size
self.num_heads = self.layers[0].attention.num_heads
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.embed_in, TensorParallelEmbedding):
self.embed_in.add_null_idx()
for layer in self.layers:
layer: FlashNeoXLayer
layer.attention.shuffle_qkv_dims()
layer.attention.query_key_value.prepare_weights(quantize)
layer.attention.dense.prepare_weights(quantize)
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashGPTNeoXModel, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward(
self,
input_ids,
......@@ -435,42 +391,13 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__(config)
self.gpt_neox = FlashGPTNeoXModel(config, weights)
self.process_group = process_group
if self.process_group is not None:
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.gpt_neox = FlashGPTNeoXModel(config, process_group)
if self.gpt_neox.tp_embeddings:
self.embed_out = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.embed_out = FastLinear(
config.hidden_size, config.vocab_size, bias=False
)
def post_load_weights(self, quantize: Optional[str] = None):
self.gpt_neox.post_load_weights(quantize)
self.embed_out.prepare_weights()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
self.embed_out = TensorParallelHead.load(
config, prefix="embed_out", weights=weights
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward(
self,
......@@ -495,12 +422,4 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.embed_out(hidden_states)
if self.gpt_neox.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present
import os
import torch
import torch.distributed
......@@ -12,15 +10,31 @@ from typing import Optional
import flash_attn_cuda
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelHead,
FastLayerNorm,
PositionRotaryEmbedding,
get_linear,
)
def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
if config.parallel_attn:
return linear
else:
return TensorParallelRowLinear(linear, process_group=weights.process_group)
class RWConfig(PretrainedConfig):
attribute_map = {
"num_hidden_layers": "n_layer",
......@@ -85,44 +99,31 @@ class RWConfig(PretrainedConfig):
class FlashRWAttention(torch.nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=None,
reduce=True,
config,
prefix,
weights,
):
super().__init__()
self.num_heads = num_heads
self.num_heads_kv = num_heads_kv
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.num_heads = config.n_head
self.num_heads_kv = config.n_head_kv
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)
self.num_heads = self.num_heads // weights.process_group.size()
if process_group is None:
self.query_key_value = FastLinear(
hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
)
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else:
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.query_key_value",
weights=weights,
bias=config.bias,
)
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
def forward(
self,
......@@ -212,57 +213,48 @@ class FlashRWAttention(torch.nn.Module):
class FlashRWLargeAttention(torch.nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=None,
reduce=True,
config,
prefix,
weights,
):
super().__init__()
hidden_size = config.hidden_size
num_heads = config.n_head
num_heads_kv = config.n_head_kv
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)
self.num_groups = num_heads // (num_heads_kv * 2)
self.num_heads = num_heads // self.num_groups
self.num_heads_kv = num_heads_kv // self.num_groups
process_group = weights.process_group
if process_group is None:
self.query_key_value = FastLinear(
hidden_size,
self.num_groups
* self.head_size
* (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
)
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else:
if process_group.size() > self.num_groups:
raise NotImplementedError(
f"Tensor Parallelism is not implemented for world_size > n groups"
)
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
self.num_groups
* self.head_size
* (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
process_group=process_group,
if process_group.size() > self.num_groups:
raise NotImplementedError(
f"Tensor Parallelism is not implemented for world_size > n groups"
)
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
if self.num_groups % process_group.size() != 0:
raise NotImplementedError(
f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
)
self.num_groups = self.num_groups // process_group.size()
self.num_groups = self.num_groups // process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.query_key_value",
weights=weights,
bias=config.bias,
)
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
def forward(
self,
......@@ -359,28 +351,16 @@ class FlashRWLargeAttention(torch.nn.Module):
class FlashMLP(nn.Module):
def __init__(self, hidden_size, bias, process_group=None, reduce=True):
def __init__(self, config, prefix, weights):
super().__init__()
self.act = torch.nn.functional.gelu
if process_group is None:
self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias)
self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
4 * hidden_size,
bias=bias,
process_group=process_group,
)
self.dense_4h_to_h = TensorParallelRowLinear(
4 * hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.process_group = process_group
self.dense_h_to_4h = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
)
self.dense_4h_to_h = load_row(
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
)
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
......@@ -392,38 +372,44 @@ class FlashMLP(nn.Module):
class FlashRWLayer(nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
layer_norm_eps,
parallel_attn,
process_group=None,
layer_id,
config,
weights,
):
super().__init__()
parallel_attn = config.parallel_attn
self.parallel_attn = parallel_attn
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
prefix = f"transformer.h.{layer_id}"
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.self_attention = FlashRWAttention(
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=process_group,
reduce=False,
config,
prefix=f"{prefix}.self_attention",
weights=weights,
)
self.post_attention_layernorm = (
FastLayerNorm(hidden_size, eps=layer_norm_eps)
FastLayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
if not parallel_attn
else None
)
self.mlp = FlashMLP(
hidden_size, bias, process_group=process_group, reduce=False
config,
prefix=f"{prefix}.mlp",
weights=weights,
)
self.process_group = process_group
self.process_group = weights.process_group
def forward(
self,
......@@ -454,9 +440,7 @@ class FlashRWLayer(nn.Module):
mlp_output = self.mlp(ln_hidden_states)
intermediate = mlp_output + attn_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group)
torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate, residual
else:
......@@ -483,33 +467,30 @@ class FlashRWLayer(nn.Module):
class FlashRWLargeLayer(nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
layer_norm_eps,
process_group=None,
):
def __init__(self, layer_id, config, weights):
super().__init__()
self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
prefix = f"transformer.h.{layer_id}"
self.ln_attn = FastLayerNorm.load(
prefix=f"{prefix}.ln_attn",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.ln_mlp = FastLayerNorm.load(
prefix=f"{prefix}.ln_mlp",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.self_attention = FlashRWLargeAttention(
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=process_group,
reduce=False,
config,
prefix=f"{prefix}.self_attention",
weights=weights,
)
assert config.parallel_attn, "This version doesn't support non parallel_attn"
self.mlp = FlashMLP(
hidden_size, bias, process_group=process_group, reduce=False
)
self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
self.process_group = process_group
self.process_group = weights.process_group
def forward(
self,
......@@ -543,9 +524,7 @@ class FlashRWLargeLayer(nn.Module):
intermediate = attn_output + mlp_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group)
torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate, residual
......@@ -555,37 +534,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
class FlashRWModel(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__(config)
self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.word_embeddings = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
else:
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.word_embeddings = TensorParallelEmbedding(
prefix="transformer.word_embeddings", weights=weights
)
if config.model_type == "RefinedWebModel":
self.h = nn.ModuleList(
[
FlashRWLayer(
config.n_head,
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
config.parallel_attn,
process_group,
)
for _ in range(config.num_hidden_layers)
FlashRWLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
......@@ -596,15 +556,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
elif config.model_type == "RefinedWeb":
self.h = nn.ModuleList(
[
FlashRWLargeLayer(
config.n_head,
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
process_group,
)
for _ in range(config.num_hidden_layers)
FlashRWLargeLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
......@@ -617,31 +570,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
f"model_type {config.model_type} is not supported."
)
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.head_size = self.h[0].self_attention.head_size
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.word_embeddings, TensorParallelEmbedding):
self.word_embeddings.add_null_idx()
for layer in self.h:
layer: FlashRWLayer
layer.self_attention.query_key_value.prepare_weights(quantize)
layer.self_attention.dense.prepare_weights(quantize)
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWModel, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f",
weights=weights,
eps=config.layer_norm_epsilon,
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
self.head_size = self.h[0].self_attention.head_size
def forward(
self,
......@@ -708,40 +643,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
class FlashRWForCausalLM(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__(config)
self.process_group = process_group
if self.process_group is not None:
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.transformer = FlashRWModel(config, process_group)
self.transformer = FlashRWModel(config, weights)
if self.transformer.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.transformer.post_load_weights(quantize)
self.lm_head.prepare_weights()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
self.lm_head = TensorParallelHead.load(
config, prefix="lm_head", weights=weights
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward(
self,
......@@ -766,12 +675,4 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present
......@@ -8,39 +8,142 @@ from typing import Optional
# Flash attention imports
import flash_attn_cuda
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelHead,
TensorParallelEmbedding,
FastLayerNorm,
get_linear,
)
class FlashMQAttention(torch.nn.Module):
def __init__(
self,
num_heads,
def load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
if any("c_attn" in k for k in weights.routing.keys()):
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
shape = slice_.get_shape()
world_size = weights.process_group.size()
rank = weights.process_group.rank()
if config.transpose:
block_size = (shape[1] - 2 * head_size) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
assert (shape[1] - 2 * head_size) % world_size == 0
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :]
weight = torch.cat([q_tensor, kv_tensor], dim=1).T
else:
block_size = (shape[0] - 2 * head_size) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
assert (shape[0] - 2 * head_size) % world_size == 0
q_tensor = slice_[start:stop]
kv_tensor = slice_[-2 * head_size :]
weight = torch.cat([q_tensor, kv_tensor], dim=0)
if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
shape = slice_.get_shape()
block_size = (shape[0] - 2 * head_size) // world_size
assert (shape[0] - 2 * head_size) % world_size == 0
q_tensor = slice_[start:stop]
start = rank * block_size
stop = (rank + 1) * block_size
q_tensor = slice_[start:stop]
kv_tensor = slice_[-2 * head_size :]
bias = torch.cat([q_tensor, kv_tensor], dim=0)
else:
if config.transpose:
w = [
weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
weights.get_tensor(f"{prefix}.kv_attn.weight").T,
]
weight = torch.cat(w, dim=0)
else:
w = [
weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
weights.get_tensor(f"{prefix}.kv_attn.weight"),
]
weight = torch.cat(w, dim=1)
if bias:
b = [
weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
weights.get_tensor(f"{prefix}.kv_attn.bias"),
]
bias = torch.cat(b, dim=0)
else:
bias = None
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
assert list(weight.shape) == [
(num_heads + 2) * head_size,
hidden_size,
process_group=None,
):
], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
if bias is not None:
bias = bias.to(dtype=weights.dtype).to(device=weights.device)
assert list(bias.shape) == [
(num_heads + 2) * head_size
], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def load_col(config, prefix: str, weights, bias: bool):
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def load_row(config, prefix: str, weights, bias: bool):
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group
)
class FlashMQAttention(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
num_heads = config.num_attention_heads
hidden_size = config.hidden_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
assert self.num_heads % weights.process_group.size() == 0
self.num_heads = self.num_heads // weights.process_group.size()
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
self.c_proj = FastLinear(hidden_size, hidden_size)
else:
self.num_heads = self.num_heads // process_group.size()
self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
self.c_proj = TensorParallelRowLinear(
hidden_size,
hidden_size,
process_group=process_group,
)
self.c_attn = load_multi_mqa(
config,
prefix=prefix,
weights=weights,
bias=True,
head_size=self.head_size,
hidden_size=hidden_size,
num_heads=self.num_heads,
)
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
def forward(
self,
......@@ -121,8 +224,9 @@ class FlashMQAttention(torch.nn.Module):
class MLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.activation_function
self.act = (
ACT2FN[act]
if "gelu" not in act
......@@ -134,20 +238,12 @@ class MLP(nn.Module):
)
)
if process_group is None:
self.c_fc = FastLinear(hidden_size, intermediate_size)
self.c_proj = FastLinear(intermediate_size, hidden_size)
else:
self.c_fc = TensorParallelColumnLinear(
hidden_size,
intermediate_size,
process_group=process_group,
)
self.c_proj = TensorParallelRowLinear(
intermediate_size,
hidden_size,
process_group=process_group,
)
self.c_fc = load_col(
config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
)
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
......@@ -157,28 +253,24 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(
self,
num_heads,
act,
hidden_size,
intermediate_size,
layer_norm_eps,
process_group=None,
):
def __init__(self, layer_id, config, weights):
super().__init__()
self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
prefix = f"transformer.h.{layer_id}"
self.ln_1 = FastLayerNorm.load(
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
)
self.ln_2 = FastLayerNorm.load(
prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
)
self.attn = FlashMQAttention(
num_heads,
hidden_size,
process_group,
prefix=f"{prefix}.attn",
config=config,
weights=weights,
)
self.mlp = MLP(
act,
hidden_size,
intermediate_size,
process_group,
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
)
def forward(
......@@ -210,66 +302,39 @@ class Block(nn.Module):
class FlashSantacoderModel(nn.Module):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__()
self.config = config
self.process_group = process_group
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.wte = TensorParallelEmbedding(
config.vocab_size,
config.hidden_size,
reduce=False,
process_group=process_group,
)
self.wpe = TensorParallelEmbedding(
config.max_position_embeddings,
config.hidden_size,
reduce=False,
process_group=process_group,
)
else:
self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.process_group = weights.process_group
self.wte = TensorParallelEmbedding(
prefix="transformer.wte",
weights=weights,
reduce=False,
)
self.wpe = TensorParallelEmbedding(
prefix="transformer.wpe",
weights=weights,
reduce=False,
)
self.h = nn.ModuleList(
[
Block(
config.num_attention_heads,
config.activation_function,
config.hidden_size,
config.n_inner
if config.n_inner is not None
else 4 * config.hidden_size,
config.layer_norm_epsilon,
process_group,
layer_id,
config,
weights,
)
for _ in range(config.num_hidden_layers)
for layer_id in range(config.num_hidden_layers)
]
)
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
)
self.head_size = self.h[0].attn.head_size
self.num_heads = self.h[0].attn.num_heads
def post_load_weights(self, quantize: Optional[str] = None):
if self.tp_embeddings:
self.wte.add_null_idx()
self.wpe.add_null_idx()
for layer in self.h:
layer: Block
layer.attn.c_attn.prepare_weights(quantize)
layer.attn.c_proj.prepare_weights(quantize)
layer.mlp.c_fc.prepare_weights(quantize)
layer.mlp.c_proj.prepare_weights(quantize)
def forward(
self,
input_ids,
......@@ -281,8 +346,7 @@ class FlashSantacoderModel(nn.Module):
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
if self.tp_embeddings:
torch.distributed.all_reduce(hidden_states, group=self.process_group)
torch.distributed.all_reduce(hidden_states, group=self.process_group)
# Prefill
if past_key_values is None:
......@@ -331,23 +395,12 @@ class FlashSantacoderModel(nn.Module):
class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__()
self.transformer = FlashSantacoderModel(config, process_group)
if self.transformer.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.transformer.post_load_weights(quantize)
self.lm_head.prepare_weights()
self.transformer = FlashSantacoderModel(config, weights)
self.lm_head = TensorParallelHead.load(
config, prefix="transformer.wte", weights=weights
)
def forward(
self,
......@@ -372,29 +425,4 @@ class FlashSantacoderForCausalLM(nn.Module):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:
# Logits are sharded, so we need to gather them
if logits.shape[0] == 1:
# Fast path when batch size is 1
world_logits = logits.new_empty(
(logits.shape[1] * self.transformer.tp_world_size)
)
torch.distributed.all_gather_into_tensor(
world_logits, logits.view(-1), group=self.transformer.process_group
)
world_logits = world_logits.view(1, -1)
else:
# We cannot use all_gather_into_tensor as it only support concatenating on the first dim
world_logits = [
torch.empty_like(logits)
for _ in range(self.transformer.tp_world_size)
]
torch.distributed.all_gather(
world_logits, logits, group=self.transformer.process_group
)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present
# coding=utf-8
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch GPTNeoX model."""
from typing import Optional, Tuple, Union
import os
import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import GPTNeoXConfig
from loguru import logger
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelHead,
)
CUSTOM_KERNELS_ENABLED = False
if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
try:
from custom_kernels import fused_attention_cuda
CUSTOM_KERNELS_ENABLED = True
except ImportError:
pass
if not CUSTOM_KERNELS_ENABLED:
logger.warning("We're not using custom kernels.")
def make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.ones(
(target_length, target_length + past_key_values_length),
dtype=torch.bool,
device=device,
)
mask = mask.triu(1 + past_key_values_length)
expanded_mask = mask.unsqueeze(0).expand(
batch_size, target_length, target_length + past_key_values_length
)
return expanded_mask
def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
"""
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
"""
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, tgt_length, src_length)
def prepare_attn_mask(
attention_mask: torch.Tensor,
input_shape: Tuple[int, int],
past_key_values_length: int,
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
)
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
class GPTNeoXPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
class GPTNeoXAttention(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
max_positions = config.max_position_embeddings
# ??? TODO
# self.register_buffer(
# "bias",
# torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
# 1, 1, max_positions, max_positions
# ),
# )
# self.register_buffer("masked_bias", torch.tensor(-1e9))
self.rotary_emb = RotaryEmbedding(
self.rotary_ndims,
config.max_position_embeddings,
base=config.rotary_emb_base,
)
self.rotary_emb.inv_freq = nn.Parameter(
weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
)
self.inv_norm_factor = 1.0 / torch.sqrt(
torch.tensor(self.head_size, dtype=torch.float32)
).to(torch.get_default_dtype())
assert self.num_attention_heads % weights.process_group.size() == 0
self.num_attention_heads = (
self.num_attention_heads // weights.process_group.size()
)
self.query_key_value = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True
)
self.dense = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
def forward(
self,
hidden_states,
position_ids,
attention_mask,
head_mask=None,
layer_past=None,
use_cache=False,
output_attentions=False,
):
has_layer_past = layer_past is not None
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
qkv = self.query_key_value(hidden_states)
# [batch, seq_len, (num_heads * 3 * head_size)]
# --> [batch, seq_len, num_heads, 3 * head_size]
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3)
# [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
query, key, value = qkv.split(self.head_size, -1)
# Compute token offset for rotary embeddings (when decoding)
seq_len = key.shape[-2]
if has_layer_past:
seq_len += layer_past[0].shape[-2]
# Compute rotary embeddings on rotary_ndims
query_rot = query[..., : self.rotary_ndims]
key_rot = key[..., : self.rotary_ndims]
query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len)
query[..., : self.rotary_ndims] = query_rot
key[..., : self.rotary_ndims] = key_rot
if CUSTOM_KERNELS_ENABLED:
attn_output, present, attn_weights = fused_attention_cuda.forward(
query,
key,
value,
layer_past,
attention_mask,
head_mask,
self.inv_norm_factor,
self.num_attention_heads,
use_cache,
)
else:
# Cache QKV values
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None
# Compute attention
attn_output, attn_weights = self._attn(
query, key, value, attention_mask, head_mask
)
# Reshape outputs
attn_output = self._merge_heads(
attn_output, self.num_attention_heads, self.head_size
)
attn_output = self.dense(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs
@classmethod
def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
# tensor: [bs, seq_len, hidden_size]
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(new_shape)
# -> [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3)
return tensor
@classmethod
def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# tensor [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3).contiguous()
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(
tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
)
# -> [bs, seq_len, hidden_size]
return tensor
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
# compute causal mask from causal mask buffer
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
query = query.view(
batch_size * num_attention_heads, query_length, attn_head_size
)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.zeros(
1,
dtype=query.dtype,
device=key.device,
).expand(batch_size * num_attention_heads, query_length, key_length)
attn_scores = torch.baddbmm(
attn_scores,
query,
key.transpose(1, 2),
beta=1.0,
alpha=self.inv_norm_factor,
)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attn_scores.dtype
if input_dtype in [torch.float16, torch.bfloat16]:
attn_scores = attn_scores.to(torch.float)
attn_scores = torch.where(
attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores
)
attn_scores = attn_scores.view(
batch_size, num_attention_heads, query_length, key_length
)
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
self.true_inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2).float().to(device) / dim)
)
self.register_buffer("inv_freq", self.true_inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
self.cos_cached = None
self.sin_cached = None
@staticmethod
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
@staticmethod
def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
t = torch.arange(
max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype
)
freqs = torch.einsum("i,j->ij", t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype)
def forward(self, q, k, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if (
seq_len > self.max_seq_len_cached
or self.cos_cached is None
or self.sin_cached is None
):
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
self.cos_cached, self.sin_cached = self._create_cos_sin(
self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device
)
return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids)
@torch.jit.script
def rotary_forward(q, k, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
chunk_size = q.shape[-1] // 2
q1, q2 = q.split(chunk_size, -1)
q_rotated = torch.cat((-q2, q1), dim=-1)
k1, k2 = k.split(chunk_size, -1)
k_rotated = torch.cat((-k2, k1), dim=-1)
q_embed = (q * cos) + (q_rotated * sin)
k_embed = (k * cos) + (k_rotated * sin)
return q_embed, k_embed
class GPTNeoXMLP(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.act = (
ACT2FN[config.hidden_act]
if "gelu_fast" not in config.hidden_act
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
)
self.dense_h_to_4h = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
)
self.dense_4h_to_h = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
)
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dense_4h_to_h(hidden_states)
return hidden_states
class GPTNeoXLayer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.attention = GPTNeoXAttention(
config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
)
self.mlp = GPTNeoXMLP(
config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
)
def forward(
self,
hidden_states,
position_ids,
attention_mask=None,
head_mask=None,
use_cache=False,
layer_past=None,
output_attentions=False,
):
attention_layer_outputs = self.attention(
self.input_layernorm(hidden_states),
attention_mask=attention_mask,
position_ids=position_ids,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attention_layer_outputs[
0
] # output_attn: attn_output, present, (attn_weights)
outputs = attention_layer_outputs[1:]
if self.use_parallel_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
hidden_states = mlp_output + attn_output + hidden_states
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
hidden_states = mlp_output + attn_output
if use_cache:
outputs = (
hidden_states,
) + outputs # hidden_states, present, (attn_weights)
else:
outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
return outputs
class GPTNeoXModel(GPTNeoXPreTrainedModel):
def __init__(self, config, weights):
super().__init__(config)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights
)
self.layers = nn.ModuleList(
[
GPTNeoXLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.final_layer_norm = nn.LayerNorm.load(
prefix="gpt_neox.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
self.tp_world_size = weights.process_group.size()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids=None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * self.config.num_hidden_layers)
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_length, seq_length + past_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)
hidden_states = inputs_embeds
# Attention mask.
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[-1]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), device=hidden_states.device
)
else:
attention_mask = attention_mask.to(hidden_states.device)
causal_mask = prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
assert self.num_attention_heads % self.tp_world_size == 0
block_size = self.num_attention_heads // self.tp_world_size
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = layer(
hidden_states,
position_ids=position_ids,
attention_mask=causal_mask,
head_mask=head_mask[i],
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.final_layer_norm(hidden_states)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_attentions]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config, weights):
super().__init__(config)
self.gpt_neox = GPTNeoXModel(config, weights)
self.embed_out = TensorParallelHead.load(
config, prefix="embed_out", weights=weights
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
`past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
>>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b")
>>> config.is_decoder = True
>>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.gpt_neox(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
lm_logits = self.embed_out(hidden_states)
lm_loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# we are doing next-token prediction; shift prediction scores and input ids by one
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithPast(
loss=lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
):
input_shape = input_ids.shape
# cut decoder_input_ids if past is used
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
)
return model_inputs
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx)
for past_state in layer_past[:2]
)
+ layer_past[2:],
)
return reordered_past
# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OPT model."""
import random
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import OPTConfig
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelHead,
)
EPS = 1e-5
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full(
(tgt_len, tgt_len),
torch.tensor(torch.finfo(dtype).min, device=device),
device=device,
)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
class OPTLearnedPositionalEmbedding(nn.Module):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weights):
super().__init__()
self.offset = 2
self.weight = nn.Parameter(
weights.get_tensor("model.decoder.embed_positions.weight")
)
def forward(
self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()
# create positions depending on attention_mask
positions = (
torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
).long() - 1
# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]
return torch.nn.functional.embedding(positions + self.offset, self.weight)
class OPTAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config,
prefix,
weights,
is_decoder: bool = False,
bias: bool = True,
process_group=None,
):
super().__init__()
embed_dim = config.embed_dim
num_heads = config.num_attention_heads
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = config.dropout
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
process_group = weights.process_group
assert self.num_heads % process_group.size() == 0
self.num_heads = self.num_heads // process_group.size()
self.embed_dim = self.embed_dim // process_group.size()
self.q_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias
)
self.k_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias
)
self.v_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias
)
self.out_proj = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attention_mask
)
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
if attn_weights.dtype == torch.float16:
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(torch.float16)
else:
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights_reshaped.view(
bsz * self.num_heads, tgt_len, src_len
)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class OPTDecoderLayer(nn.Module):
def __init__(self, layer_id: int, config: OPTConfig, weights):
super().__init__()
self.process_group = weights.process_group
self.embed_dim = config.hidden_size
prefix = f"model.decoder.layers.{layer_id}"
self.self_attn = OPTAttention(
config,
prefix=f"{prefix}.self_attn",
weights=weights,
is_decoder=True,
bias=config.enable_bias,
)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.self_attn_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS
)
self.fc1 = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias
)
self.fc2 = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias
)
self.final_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = (residual + hidden_states).view(hidden_states_shape)
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class OPTPreTrainedModel(PreTrainedModel):
config_class = OPTConfig
class OPTDecoder(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
self.embed_tokens = TensorParallelEmbedding(
prefix="model.decoder.embed_tokens", weights=weights
)
self.embed_positions = OPTLearnedPositionalEmbedding(weights)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load(
config, prefix="model.decoder.project_out", bias=False
)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load(
config, prefix="model.decoder.project_in", bias=False
)
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load(
prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS
)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList(
[
OPTDecoderLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
past_key_values_length = (
past_key_values[0][0].shape[2] if past_key_values is not None else 0
)
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
batch_size, mask_seq_length, device=inputs_embeds.device
)
causal_attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class OPTModel(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights):
super().__init__(config)
self.decoder = OPTDecoder(config, weights)
# Initialize weights and apply final processing
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs
return BaseModelOutputWithPast(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
hidden_states=decoder_outputs.hidden_states,
attentions=decoder_outputs.attentions,
)
class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, config, weights):
super().__init__(config)
self.model = OPTModel(config, weights)
self.lm_head = TensorParallelHead.load(
config, prefix="model.decoder.embed_tokens", weights=weights
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0]).contiguous()
loss = None
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx) for past_state in layer_past
),
)
return reordered_past
# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model."""
import copy
import math
import warnings
from typing import Optional, Tuple, Union
import torch
import torch.distributed
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
is_torch_fx_proxy,
)
from transformers import T5Config
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelHead,
)
class PartialTPEmbedding(nn.Module):
def __init__(self, prefix: str, weights):
super().__init__()
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
self.weight = nn.Parameter(weight)
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.embedding(input, self.weight)
@torch.jit.script
def layer_norm(hidden_states, weight, epsilon):
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
# half-precision inputs is done in fp32
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + epsilon)
# convert into half-precision if necessary
if weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(weight.dtype)
return weight * hidden_states
class T5LayerNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
super().__init__()
weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = torch.tensor(eps)
def forward(self, hidden_states):
return layer_norm(hidden_states, self.weight, self.variance_epsilon)
try:
from apex.normalization import FusedRMSNorm
T5LayerNorm = FusedRMSNorm # noqa
logger.info(
"Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm"
)
except ImportError:
# using the normal T5LayerNorm
pass
except Exception:
logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
pass
ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
class T5DenseActDense(nn.Module):
def __init__(self, config: T5Config, prefix, weights):
super().__init__()
self.wi = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.wi", weights=weights, bias=False
)
### XXX: T5 models do not handle well both f16 and quantization.
### Overidding specifically this layer for that reason.
### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
### https://github.com/huggingface/transformers/issues/20287
_q = config.quantize
_dtype = weights.dtype
weights.dtype = torch.float32
config.quantize = None
self.wo_cast = (torch.float32, _dtype)
self.wo = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.wo", weights=weights, bias=False
)
weights.dtype = _dtype
config.quantize = _q
self.dropout = nn.Dropout(config.dropout_rate)
self.act = (
ACT2FN[config.dense_act_fn]
if "gelu" not in config.dense_act_fn
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
)
def forward(self, hidden_states):
hidden_states = self.wi(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states.to(dtype=self.wo_cast[0])
hidden_states = self.wo(hidden_states)
# XXX: Recasting is already done within the layer norm.
# Casting back to float16 here modifies results
# hidden_states = hidden_states.to(dtype=self.wo_cast[1])
return hidden_states
class T5DenseGatedActDense(nn.Module):
def __init__(self, config: T5Config, prefix, weights):
super().__init__()
self.wi_0 = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.wi_0", weights=weights, bias=False
)
self.wi_1 = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.wi_1", weights=weights, bias=False
)
### XXX: T5 models do not handle well both f16 and quantization.
### Overidding specifically this layer for that reason.
### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
### https://github.com/huggingface/transformers/issues/20287
_q = config.quantize
_dtype = weights.dtype
weights.dtype = torch.float32
config.quantize = None
self.wo_cast = (torch.float32, _dtype)
self.wo = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.wo", weights=weights, bias=False
)
weights.dtype = _dtype
config.quantize = _q
self.dropout = nn.Dropout(config.dropout_rate)
self.act = (
ACT2FN[config.dense_act_fn]
if "gelu" not in config.dense_act_fn
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
)
def forward(self, hidden_states):
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states.to(dtype=self.wo_cast[0])
hidden_states = self.wo(hidden_states)
# XXX: Recasting is already done within the layer norm.
# Casting back to float16 here modifies results
# hidden_states = hidden_states.to(dtype=self.wo_cast[1])
return hidden_states
class T5LayerFF(nn.Module):
def __init__(self, config: T5Config, prefix, weights):
super().__init__()
if config.is_gated_act:
self.DenseReluDense = T5DenseGatedActDense(
config, prefix=f"{prefix}.DenseReluDense", weights=weights
)
else:
self.DenseReluDense = T5DenseActDense(
config, prefix=f"{prefix}.DenseReluDense", weights=weights
)
self.layer_norm = T5LayerNorm(
prefix=f"{prefix}.layer_norm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states):
forwarded_states = self.layer_norm(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = hidden_states + self.dropout(forwarded_states)
return hidden_states
class T5Attention(nn.Module):
def __init__(
self, config: T5Config, prefix, weights, has_relative_attention_bias=False
):
super().__init__()
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.relative_attention_max_distance = config.relative_attention_max_distance
self.d_model = config.d_model
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim
process_group = weights.process_group
# Mesh TensorFlow initialization to avoid scaling before softmax
assert self.n_heads % process_group.size() == 0
self.q = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q", weights=weights, bias=False
)
self.k = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.k", weights=weights, bias=False
)
self.v = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.v", weights=weights, bias=False
)
self.o = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.o", weights=weights, bias=False
)
self.n_heads = self.n_heads // process_group.size()
self.inner_dim = self.inner_dim // process_group.size()
if self.has_relative_attention_bias:
self.relative_attention_bias = PartialTPEmbedding(
prefix=f"{prefix}.relative_attention_bias", weights=weights
)
@staticmethod
def _relative_position_bucket(
relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(
relative_position, torch.zeros_like(relative_position)
)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large,
torch.full_like(relative_position_if_large, num_buckets - 1),
)
relative_buckets += torch.where(
is_small, relative_position, relative_position_if_large
)
return relative_buckets
def compute_bias(self, query_length, key_length, device=None):
"""Compute binned relative position bias"""
if device is None:
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[
:, None
]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
None, :
]
relative_position = (
memory_position - context_position
) # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(
relative_position_bucket
) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(
0
) # shape (1, num_heads, query_length, key_length)
return values
def forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
output_attentions=False,
):
"""
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length
if past_key_value is not None:
assert (
len(past_key_value) == 2
), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
real_seq_length += (
past_key_value[0].shape[2] if query_length is None else query_length
)
key_length = (
real_seq_length if key_value_states is None else key_value_states.shape[1]
)
def shape(states):
"""projection"""
return states.view(
batch_size, -1, self.n_heads, self.key_value_proj_dim
).transpose(1, 2)
def unshape(states):
"""reshape"""
return (
states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else:
# cross-attn
hidden_states = past_key_value
return hidden_states
# get query states
query_states = shape(
self.q(hidden_states)
) # (batch_size, n_heads, seq_length, dim_per_head)
# get key/value states
key_states = project(
hidden_states,
self.k,
key_value_states,
past_key_value[0] if past_key_value is not None else None,
)
value_states = project(
hidden_states,
self.v,
key_value_states,
past_key_value[1] if past_key_value is not None else None,
)
# compute scores
scores = torch.matmul(
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
if position_bias is None:
if not self.has_relative_attention_bias:
position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length),
device=scores.device,
dtype=scores.dtype,
)
else:
position_bias = self.compute_bias(
real_seq_length, key_length, device=scores.device
)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
if mask is not None:
position_bias = (
position_bias + mask
) # (batch_size, n_heads, seq_length, key_length)
position_bias_masked = position_bias
scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores
) # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to
if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask
attn_output = unshape(
torch.matmul(attn_weights, value_states)
) # (batch_size, seq_length, dim)
attn_output = self.o(attn_output)
present_key_value_state = (
(key_states, value_states) if (self.is_decoder and use_cache) else None
)
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions:
outputs = outputs + (attn_weights,)
return outputs
class T5LayerSelfAttention(nn.Module):
def __init__(self, config, prefix, weights, has_relative_attention_bias=False):
super().__init__()
self.SelfAttention = T5Attention(
config,
prefix=f"{prefix}.SelfAttention",
weights=weights,
has_relative_attention_bias=has_relative_attention_bias,
)
self.layer_norm = T5LayerNorm(
prefix=f"{prefix}.layer_norm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[
1:
] # add attentions if we output them
return outputs
class T5LayerCrossAttention(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.EncDecAttention = T5Attention(
config,
prefix=f"{prefix}.EncDecAttention",
weights=weights,
has_relative_attention_bias=False,
)
self.layer_norm = T5LayerNorm(
prefix=f"{prefix}.layer_norm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
key_value_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
query_length=None,
output_attentions=False,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
normed_hidden_states,
mask=attention_mask,
key_value_states=key_value_states,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
query_length=query_length,
output_attentions=output_attentions,
)
layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[
1:
] # add attentions if we output them
return outputs
class T5Block(nn.Module):
def __init__(self, config, prefix, weights, has_relative_attention_bias: bool):
super().__init__()
self.is_decoder = config.is_decoder
self.layer = nn.ModuleList()
self.layer.append(
T5LayerSelfAttention(
config,
prefix=f"{prefix}.layer.0",
weights=weights,
has_relative_attention_bias=has_relative_attention_bias,
)
)
if self.is_decoder:
i = 2
self.layer.append(
T5LayerCrossAttention(
config, prefix=f"{prefix}.layer.1", weights=weights
)
)
else:
i = 1
self.layer.append(
T5LayerFF(config, prefix=f"{prefix}.layer.{i}", weights=weights)
)
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
return_dict=True,
):
if past_key_value is not None:
if not self.is_decoder:
logger.warning(
"`past_key_values` is passed to the encoder. Please make sure this is intended."
)
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[
2:
] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(
torch.isinf(hidden_states).any(),
torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max,
)
hidden_states = torch.clamp(
hidden_states, min=-clamp_value, max=clamp_value
)
do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = present_key_value_state[0].shape[2]
else:
query_length = None
cross_attention_outputs = self.layer[1](
hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
query_length=query_length,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = cross_attention_outputs[0]
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(
torch.isinf(hidden_states).any(),
torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max,
)
hidden_states = torch.clamp(
hidden_states, min=-clamp_value, max=clamp_value
)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = (
present_key_value_state + cross_attention_outputs[1]
)
# Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:]
# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states)
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(
torch.isinf(hidden_states).any(),
torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max,
)
hidden_states = torch.clamp(
hidden_states, min=-clamp_value, max=clamp_value
)
outputs = (hidden_states,)
if use_cache:
outputs = outputs + (present_key_value_state,) + attention_outputs
else:
outputs = outputs + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
class T5PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = T5Config
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
assert decoder_start_token_id is not None, (
"self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
" See T5 docs for more information"
)
# shift inputs to the right
if is_torch_fx_proxy(input_ids):
# Item assignment is not supported natively for proxies.
shifted_input_ids = torch.full(
input_ids.shape[:-1] + (1,), decoder_start_token_id
)
shifted_input_ids = torch.cat(
[shifted_input_ids, input_ids[..., :-1]], dim=-1
)
else:
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
assert (
pad_token_id is not None
), "self.model.config.pad_token_id has to be defined."
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
class T5Stack(T5PreTrainedModel):
def __init__(self, config, prefix, weights, embed_tokens):
super().__init__(config)
self.is_decoder = config.is_decoder
self.embed_tokens = embed_tokens
self.block = nn.ModuleList(
[
T5Block(
config,
prefix=f"{prefix}.block.{layer_id}",
weights=weights,
has_relative_attention_bias=(layer_id == 0),
)
for layer_id in range(config.num_layers)
]
)
self.final_layer_norm = T5LayerNorm(
prefix=f"{prefix}.final_layer_norm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
# Model parallel
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(
f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(
f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
)
if inputs_embeds is None:
assert (
self.embed_tokens is not None
), "You have to initialize the model with valid token embeddings"
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past
mask_seq_length = (
past_key_values[0][0].shape[2] + seq_length
if past_key_values is not None
else seq_length
)
if use_cache is True:
assert (
self.is_decoder
), f"`use_cache` can only be set to `True` if {self} is used as a decoder"
if attention_mask is None:
attention_mask = torch.ones(
batch_size, mask_seq_length, device=inputs_embeds.device
)
if (
self.is_decoder
and encoder_attention_mask is None
and encoder_hidden_states is not None
):
encoder_seq_length = encoder_hidden_states.shape[1]
encoder_attention_mask = torch.ones(
batch_size,
encoder_seq_length,
device=inputs_embeds.device,
dtype=torch.long,
)
# initialize past_key_values with `None` if past does not exist
if past_key_values is None:
past_key_values = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(
attention_mask, input_shape
)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.is_decoder and encoder_hidden_states is not None:
(
encoder_batch_size,
encoder_sequence_length,
_,
) = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=inputs_embeds.device
)
encoder_extended_attention_mask = self.invert_attention_mask(
encoder_attention_mask
)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(
cross_attn_head_mask, self.config.num_layers
)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.is_decoder) else None
position_bias = None
encoder_decoder_position_bias = None
hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value) in enumerate(
zip(self.block, past_key_values)
):
layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i]
# Model parallel
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states,
attention_mask=extended_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
if use_cache is False:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2]
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
# (cross-attention position bias), (cross-attention weights)
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[
4 if output_attentions else 3
]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (
present_key_value_state,
)
if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],)
if self.is_decoder:
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
present_key_value_states,
all_hidden_states,
all_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_value_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
class T5ForConditionalGeneration(T5PreTrainedModel):
def __init__(self, config: T5Config, weights):
super().__init__(config)
self.model_dim = config.d_model
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = T5Stack(
config=encoder_config,
prefix="encoder",
weights=weights,
embed_tokens=self.shared,
)
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = config.num_decoder_layers
self.decoder = T5Stack(
config=decoder_config,
prefix="decoder",
weights=weights,
embed_tokens=self.shared,
)
self.lm_head = TensorParallelHead.load(
config, prefix="lm_head", weights=weights
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
if self.config.num_layers == self.config.num_decoder_layers:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
hidden_states = encoder_outputs[0]
if (
labels is not None
and decoder_input_ids is None
and decoder_inputs_embeds is None
):
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_values=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = decoder_outputs[0]
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5)
lm_logits = self.lm_head(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
if not return_dict:
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
decoder_attention_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"decoder_attention_mask": decoder_attention_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)
def _reorder_cache(self, past_key_values, beam_idx):
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
if past_key_values is None:
logger.warning(
"You might want to consider setting `use_cache=True` to speed up decoding"
)
return past_key_values
reordered_decoder_past = ()
for layer_past_states in past_key_values:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(
0, beam_idx.to(layer_past_state.device)
),
)
assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
assert len(reordered_layer_past_states) == len(layer_past_states)
reordered_decoder_past = reordered_decoder_past + (
reordered_layer_past_states,
)
return reordered_decoder_past
import torch
import torch.distributed
from accelerate import init_empty_weights
from opentelemetry import trace
from pathlib import Path
from safetensors import safe_open
from transformers import AutoConfig
from transformers.models.llama import LlamaTokenizer
from typing import Optional, List
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
weight_hub_files,
LocalEntryNotFoundError,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashLlama(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("FlashLlama is only available on GPU")
tokenizer = LlamaTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
# We do not use from_pretrained as we modified the model internal module layout
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashLlamaForCausalLM(config)
self.load_weights(model, filenames, quantize, device, dtype)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
value = value.to(device if quantize is None else "cpu").to(dtype)
layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_proj" in key or "k_proj" in key or "v_proj" in key:
final_key = layer_name + ".query_key_value.weight"
# Fused gate and up projs
elif "gate_proj" in key or "up_proj" in key:
final_key = layer_name + ".gate_up_proj.weight"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "query_key_value" in final_key:
module._parameters[param_name] = value.new_empty(
(value.shape[0] * 3, value.shape[1])
)
# Init gate and up proj
elif "gate_up_proj" in final_key:
module._parameters[param_name] = value.new_empty(
(value.shape[0] * 2, value.shape[1])
)
# Copy to correct slice
if "q_proj" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "k_proj" in key:
module._parameters[param_name][
value.shape[0] : value.shape[0] * 2
] = value
elif "v_proj" in key:
module._parameters[param_name][value.shape[0] * 2 :] = value
elif "gate_proj" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "up_proj" in key:
module._parameters[param_name][value.shape[0] :] = value
else:
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
class FlashLlamaSharded(FlashLlama):
def __init__(
self,
model_id: str,
......@@ -176,24 +47,16 @@ class FlashLlamaSharded(FlashLlama):
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
with init_empty_weights():
model = FlashLlamaForCausalLM(config, process_group=self.process_group)
config.quantize = quantize
model = FlashLlamaForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
......@@ -201,114 +64,3 @@ class FlashLlamaSharded(FlashLlama):
rank=rank,
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
slice_ = f.get_slice(name)
layer_name = ".".join(name.split(".")[:4])
# Fused qkv
if "q_proj" in name or "k_proj" in name or "v_proj" in name:
final_name = layer_name + ".query_key_value.weight"
# Fused gate and up projs
elif "gate_proj" in name or "up_proj" in name:
final_name = layer_name + ".gate_up_proj.weight"
else:
final_name = name
module_name, param_name = final_name.rsplit(".", 1)
module = model.get_submodule(module_name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight" and model.model.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
tensor = tensor.contiguous().to(dtype)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "query_key_value" in final_name:
module._parameters[param_name] = tensor.new_empty(
(tensor.shape[0] * 3, tensor.shape[1])
)
# Init gate and up proj
elif "gate_up_proj" in final_name:
module._parameters[param_name] = tensor.new_empty(
(tensor.shape[0] * 2, tensor.shape[1])
)
# Init gate and up proj
if "q_proj" in name:
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "k_proj" in name:
module._parameters[param_name][
tensor.shape[0] : tensor.shape[0] * 2
] = tensor
elif "v_proj" in name:
module._parameters[param_name][
tensor.shape[0] * 2 :
] = tensor
elif "gate_proj" in name:
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "up_proj" in name:
module._parameters[param_name][tensor.shape[0] :] = tensor
else:
if current_parameter_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
torch.cuda.empty_cache()
model.post_load_weights(quantize)
import torch
import torch.distributed
from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashNeoX(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
super(FlashNeoX, self).__init__(
FlashGPTNeoXForCausalLM,
model_id,
revision,
quantize,
trust_remote_code=trust_remote_code,
)
class FlashNeoXSharded(FlashNeoX):
class FlashNeoXSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
......@@ -65,23 +44,16 @@ class FlashNeoXSharded(FlashNeoX):
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
with init_empty_weights():
model = FlashGPTNeoXForCausalLM(config, self.process_group)
model = FlashGPTNeoXForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
......@@ -92,79 +64,3 @@ class FlashNeoXSharded(FlashNeoX):
rank=rank,
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
model.post_load_weights(quantize)
import torch
import torch.distributed
from pathlib import Path
from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
from transformers import AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
weight_hub_files,
LocalEntryNotFoundError,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashRW(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("RW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id,
revision=revision,
)
# We do not use from_pretrained as it is too slow
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashRWForCausalLM(config)
self.load_weights(
model,
filenames,
quantize,
device,
dtype,
)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model: FlashRWForCausalLM,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
value = value.to(device if quantize is None else "cpu").to(dtype)
module_name, param_name = key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
except KeyError:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
class FlashRWSharded(FlashRW):
class FlashRWSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
......@@ -142,20 +48,12 @@ class FlashRWSharded(FlashRW):
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
with init_empty_weights():
model = FlashRWForCausalLM(config, self.process_group)
config.quantize = quantize
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
......@@ -166,79 +64,3 @@ class FlashRWSharded(FlashRW):
rank=rank,
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
model.post_load_weights(quantize)
import torch
import torch.distributed
from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
from pathlib import Path
from transformers import AutoTokenizer, GPT2Config
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
weight_hub_files,
LocalEntryNotFoundError,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashSantacoder(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("FlashSantacoder is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = GPT2Config.from_pretrained(
model_id,
revision=revision,
)
# We do not use from_pretrained as we modified the model internal module layout
filenames = weight_files(model_id, revision, ".safetensors")
with init_empty_weights():
model = FlashSantacoderForCausalLM(config)
self.load_weights(
model,
filenames,
quantize,
device,
dtype,
config.architectures[0].startswith("GPT2"),
)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model: FlashSantacoderForCausalLM,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
transpose: bool,
):
for filename in filenames:
with safe_open(
filename,
framework="pt",
device=str(device) if quantize is None else "cpu",
) as f:
for key in f.keys():
value = f.get_tensor(key)
value = value.to(device if quantize is None else "cpu").to(dtype)
layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_attn.weight" in key or "kv_attn.weight" in key:
final_key = layer_name + ".c_attn.weight"
elif "q_attn.bias" in key or "kv_attn.bias" in key:
final_key = layer_name + ".c_attn.bias"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if transpose and (
"c_fc.weight" in key
or "c_proj.weight" in key
or "q_attn.weight" in key
or "kv_attn.weight" in key
or "c_attn.weight" in key
):
# Tranpose as we use nn.Linear instead of Conv1D
value = value.T
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "c_attn.weight" in final_key:
module._parameters[param_name] = value.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2),
value.shape[1],
)
)
elif "c_attn.bias" in final_key:
module._parameters[param_name] = value.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2)
)
)
# Copy to correct slice
if "q_attn.weight" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "q_attn.bias" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "kv_attn.weight" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = value
elif "kv_attn.bias" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = value
else:
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
del value
if model.lm_head.weight.device == torch.device("meta"):
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
torch.cuda.empty_cache()
model.post_load_weights(quantize)
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model : {uninitialized_parameters}"
)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
class FlashSantacoderSharded(FlashSantacoder):
class FlashSantacoderSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
......@@ -214,28 +41,22 @@ class FlashSantacoderSharded(FlashSantacoder):
trust_remote_code=trust_remote_code,
)
config = GPT2Config.from_pretrained(
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=True,
)
config.quantize = quantize
config.transpose = config.architectures[0].startswith("GPT2")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
with init_empty_weights():
model = FlashSantacoderForCausalLM(config, self.process_group)
model = FlashSantacoderForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
transpose=config.architectures[0].startswith("GPT2"),
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
......@@ -247,164 +68,8 @@ class FlashSantacoderSharded(FlashSantacoder):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
transpose: bool,
):
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for key in f.keys():
slice_ = f.get_slice(key)
layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_attn.weight" in key or "kv_attn.weight" in key:
final_key = layer_name + ".c_attn.weight"
elif "q_attn.bias" in key or "kv_attn.bias" in key:
final_key = layer_name + ".c_attn.bias"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
if isinstance(module, TensorParallelColumnLinear):
dim = 1 if transpose and "weight" in param_name else 0
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = (
slice_[start:stop] if dim == 0 else slice_[:, start:stop]
)
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
dim = 0 if transpose else 1
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = (
slice_[start:stop]
if dim == 0
else slice_[:, start:stop]
)
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif key == "lm_head.weight" and model.transformer.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(key)
tensor = tensor.contiguous().to(dtype)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if transpose and (
"c_fc.weight" in key
or "c_proj.weight" in key
or "q_attn.weight" in key
or "kv_attn.weight" in key
or "c_attn.weight" in key
):
# Tranpose as we use nn.Linear instead of Conv1D
tensor = tensor.T
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "c_attn.weight" in final_key:
module._parameters[param_name] = tensor.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2),
tensor.shape[1],
)
)
elif "c_attn.bias" in final_key:
module._parameters[param_name] = tensor.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2)
)
)
# Copy to correct slice
if "q_attn" in key:
size = tensor.shape[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = tensor[start:stop]
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "kv_attn.weight" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = tensor
elif "kv_attn.bias" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = tensor
elif "c_attn" in key:
# Slice q_tensor by shard
q_tensor = tensor[: -2 * model.transformer.head_size]
block_size = q_tensor.shape[0] // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q_tensor = q_tensor[start:stop]
module._parameters[param_name][
: q_tensor.shape[0]
] = q_tensor
# Kv tensor is copied for every shard
kv_tensor = tensor[-2 * model.transformer.head_size :]
module._parameters[param_name][
q_tensor.shape[0] :
] = kv_tensor
else:
if current_parameter_tensor.shape != tensor.shape:
raise ValueError(
f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
if model.lm_head.weight.device == torch.device("meta"):
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
torch.cuda.empty_cache()
model.post_load_weights(quantize)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
......@@ -2,41 +2,25 @@ import re
import torch
import torch.distributed
from typing import List, Optional, Type, Tuple
from typing import List, Optional, Type
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
PreTrainedTokenizerBase,
)
from transformers.models.opt.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.opt import OPT
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
initialize_torch_distributed,
weight_files,
Weights,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
# we split individual characters inside special tokens like [START_DNA]
......@@ -168,33 +152,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
)
class Galactica(OPT):
@property
def batch_type(self) -> Type[CausalLMBatch]:
return GalacticaCausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
class GalacticaSharded(Galactica):
class GalacticaSharded(CausalLM):
def __init__(
self,
model_id: str,
......@@ -224,26 +182,17 @@ class GalacticaSharded(Galactica):
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
model = OPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
......@@ -255,127 +204,15 @@ class GalacticaSharded(Galactica):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
if name == "lm_head.weight":
continue
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
@property
def batch_type(self) -> Type[CausalLMBatch]:
return GalacticaCausalLMBatch
module._parameters[param_name] = tensor
if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
......@@ -386,10 +223,4 @@ class GalacticaSharded(Galactica):
past_key_values=past_key_values,
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values
return outputs.logits, outputs.past_key_values
import torch
import torch.distributed
from typing import List, Optional
from typing import Optional
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
)
from transformers.models.gpt_neox.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.neox_modeling import (
GPTNeoxForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
class GPTNeoxSharded(CausalLM):
def __init__(
......@@ -58,28 +46,18 @@ class GPTNeoxSharded(CausalLM):
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
model = GPTNeoxForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
......@@ -91,161 +69,16 @@ class GPTNeoxSharded(CausalLM):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
if self.model.gpt_neox.tp_embeddings:
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(
logits, outputs.logits, group=self.process_group
)
logits = torch.cat(logits, dim=2)
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)
return logits, outputs.past_key_values
# While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
else:
return super(GPTNeoxSharded, self).forward(
input_ids, attention_mask, position_ids, past_key_values
)
logits = outputs.logits
return logits, outputs.past_key_values
import torch
import torch.distributed
from typing import List, Optional, Tuple
from typing import Optional
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
)
from transformers.models.opt.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models import CausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
class OPT(CausalLM):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
class OPTSharded(OPT):
class OPTSharded(CausalLM):
def __init__(
self,
model_id: str,
......@@ -73,29 +43,19 @@ class OPTSharded(OPT):
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
model = OPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
......@@ -107,128 +67,6 @@ class OPTSharded(OPT):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
if name == "lm_head.weight":
continue
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
......@@ -239,9 +77,4 @@ class OPTSharded(OPT):
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values
return outputs.logits, outputs.past_key_values
......@@ -3,31 +3,20 @@ import torch.distributed
from typing import List, Optional, Tuple
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoConfig,
)
from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from transformers.models.t5.parallel_layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except ImportError as e:
HAS_BITS_AND_BYTES = False
class T5Sharded(Seq2SeqLM):
......@@ -46,40 +35,30 @@ class T5Sharded(Seq2SeqLM):
device = torch.device("cpu")
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config = AutoConfig.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
tp_parallel=True,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = config.decoder_start_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
with init_empty_weights():
model = AutoModelForSeq2SeqLM.from_config(
config, trust_remote_code=trust_remote_code
)
model = T5ForConditionalGeneration(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(Seq2SeqLM, self).__init__(
model=model,
......@@ -91,151 +70,6 @@ class T5Sharded(Seq2SeqLM):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight":
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif "relative_attention_bias.weight" in name:
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous()
# See: https://github.com/huggingface/transformers/blob/1fe1e3caa44617047f149bcc0c0b566343b714a7/src/transformers/models/t5/modeling_t5.py#LL316C15-L316C71
if module_name.endswith("wo"):
tensor = tensor.to(torch.float32)
else:
tensor = tensor.to(dtype)
if quantize == "bitsandbytes" and not module_name.endswith("wo"):
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq" and not module_name.endswith("wo"):
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None or module_name.endswith("wo"):
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
def forward(
self,
input_ids,
......@@ -260,13 +94,8 @@ class T5Sharded(Seq2SeqLM):
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return (
logits,
outputs.logits,
outputs.encoder_last_hidden_state,
outputs.past_key_values,
)
from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.weights import Weights
from text_generation_server.utils.hub import (
weight_files,
weight_hub_files,
......@@ -35,4 +36,5 @@ __all__ = [
"StoppingCriteria",
"StopSequenceCriteria",
"FinishReason",
"Weights",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment