"vscode:/vscode.git/clone" did not exist on "d16a192041a1122424dc3baf67075547ff219ddf"
Unverified Commit abd58ff8 authored by Nicolas Patry, committed by GitHub

feat(server): Rework model loading (#344)

# What does this PR do?

Reworked the loading logic. The idea is to use cleaner loading code:

- Remove need for `no_init_weights`
- Remove all weird `bnb_linear` and `load_weights` and
`post_load_weights`.

New code layout:

- New class `Weights` in charge of loading the weights from multiple files
  into the appropriate tensors (potentially sharded); a rough sketch of the
  idea is given below.
- TP layers are now "shells": they contain the code that knows what kind of
  sharding we need plus the eventual `all_reduce`. They do not inherit from
  Linear, but they contain some kind of Linear instead (second sketch below).
  - The contained linear can be either FastLinear, BnbLinear or, next, GPTQ
    Linear.
- All modeling code is explicitly written for sharding; the process group is
  just a no-op for non-sharded code (removes a lot of test cases).
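
A minimal sketch of the `Weights` idea, to make the list above concrete. It is only an illustration consistent with the calls used in the diff (`get_tensor`, `get_sharded`, `process_group`), not the exact class added by this PR; internal names like `routing` are made up here, and the real code additionally handles non-float tensors, quantized weights and uneven shards:

```python
from typing import List

import torch
from safetensors import safe_open


class Weights:
    """Owns the safetensors files and hands out full or sharded tensors on demand."""

    def __init__(self, filenames: List[str], device: torch.device,
                 dtype: torch.dtype, process_group):
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        # Route every tensor name to the file that contains it.
        self.routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pt") as f:
                for name in f.keys():
                    self.routing[name] = filename

    def get_tensor(self, name: str) -> torch.Tensor:
        # Full (unsharded) tensor, e.g. a LayerNorm weight or a row-linear bias.
        with safe_open(self.routing[name], framework="pt") as f:
            tensor = f.get_tensor(name)
        return tensor.to(dtype=self.dtype, device=self.device)

    def get_sharded(self, name: str, dim: int) -> torch.Tensor:
        # Only this rank's contiguous block of the tensor along `dim`.
        rank = self.process_group.rank()
        world_size = self.process_group.size()
        with safe_open(self.routing[name], framework="pt") as f:
            slice_ = f.get_slice(name)
            size = slice_.get_shape()[dim]
            block_size = size // world_size
            start, stop = rank * block_size, (rank + 1) * block_size
            if dim == 0:
                tensor = slice_[start:stop]
            elif dim == 1:
                tensor = slice_[:, start:stop]
            else:
                raise NotImplementedError("this sketch only shards dim 0 or 1")
        return tensor.to(dtype=self.dtype, device=self.device)
```

Model classes then take `(config, weights)` and pull exactly the tensors they need, as `BLOOMSharded` does in the diff: `weights = Weights(filenames, device=device, dtype=dtype, process_group=self.process_group)` followed by `model = BloomForCausalLM(config, weights)`.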

![Screenshot from 2023-05-19 23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)
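
The TP "shells" from the list above can be pictured like this: they wrap whatever linear `get_linear` produced (FastLinear, a bitsandbytes linear, GPTQ later) and only add the sharding and communication logic. Again a hedged sketch under the same assumptions, not the PR's actual `utils/layers.py`, and it only covers the unquantized case:

```python
import torch
import torch.distributed
from torch import nn


def get_linear(weight: torch.Tensor, bias, quantize):
    # The real helper picks FastLinear, a bitsandbytes linear or (later) a GPTQ
    # linear depending on `quantize`; this sketch falls back to plain nn.Linear.
    if quantize is not None:
        raise NotImplementedError(f"sketch does not cover quantize={quantize!r}")
    out_features, in_features = weight.shape
    linear = nn.Linear(in_features, out_features, bias=bias is not None)
    linear.weight = nn.Parameter(weight)
    if bias is not None:
        linear.bias = nn.Parameter(bias)
    return linear


class TensorParallelColumnLinear(nn.Module):
    """Output features are sharded across ranks; no communication is needed."""

    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
        b = weights.get_sharded(f"{prefix}.bias", dim=0) if bias else None
        return cls(get_linear(weight, b, config.quantize))

    def forward(self, x):
        return self.linear(x)


class TensorParallelRowLinear(nn.Module):
    """Input features are sharded; partial outputs are summed with all_reduce."""

    def __init__(self, linear, process_group):
        super().__init__()
        self.linear = linear
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
        # Only rank 0 loads the bias, so it is added exactly once after the reduce.
        b = (
            weights.get_tensor(f"{prefix}.bias")
            if bias and weights.process_group.rank() == 0
            else None
        )
        return cls(get_linear(weight, b, config.quantize), weights.process_group)

    def forward(self, x):
        out = self.linear(x)
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
```

A column-sharded linear produces a slice of the output features and needs no communication; a row-sharded linear needs one `all_reduce` to sum the partial products, which is why only the row shell holds a process group, and why the parallel-residual NeoX/RW layers in the diff skip the wrapper and reduce once per layer after the attention and MLP outputs are added.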

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.taildb5d.ts.net>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
parent 19c41824
...@@ -6,12 +6,17 @@ from transformers import AutoTokenizer ...@@ -6,12 +6,17 @@ from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def default_bloom(): def default_bloom():
return BLOOM("bigscience/bloom-560m") model_id = "bigscience/bloom-560m"
revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision)
return BLOOMSharded(model_id)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
......
import os
import torch import torch
from loguru import logger from loguru import logger
...@@ -8,17 +9,20 @@ from typing import Optional ...@@ -8,17 +9,20 @@ from typing import Optional
from text_generation_server.models.model import Model from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPT, OPTSharded from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import Galactica, GalacticaSharded from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
try: try:
if torch.cuda.is_available(): if (
torch.cuda.is_available()
and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
):
major, minor = torch.cuda.get_device_capability() major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5 is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0 is_sm8x = major == 8 and minor >= 0
...@@ -30,14 +34,12 @@ try: ...@@ -30,14 +34,12 @@ try:
f"GPU with CUDA capability {major} {minor} is not supported" f"GPU with CUDA capability {major} {minor} is not supported"
) )
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import ( from text_generation_server.models.flash_llama import (
FlashLlama, FlashLlama,
FlashLlamaSharded,
) )
from text_generation_server.models.flash_santacoder import ( from text_generation_server.models.flash_santacoder import (
FlashSantacoder,
FlashSantacoderSharded, FlashSantacoderSharded,
) )
...@@ -52,30 +54,22 @@ except ImportError: ...@@ -52,30 +54,22 @@ except ImportError:
__all__ = [ __all__ = [
"Model", "Model",
"BLOOM",
"BLOOMSharded", "BLOOMSharded",
"CausalLM", "CausalLM",
"FlashCausalLM", "FlashCausalLM",
"Galactica",
"GalacticaSharded", "GalacticaSharded",
"GPTNeoxSharded",
"Seq2SeqLM", "Seq2SeqLM",
"SantaCoder", "SantaCoder",
"OPT",
"OPTSharded", "OPTSharded",
"T5Sharded", "T5Sharded",
"get_model", "get_model",
] ]
if FLASH_ATTENTION: if FLASH_ATTENTION:
__all__.append(FlashNeoX)
__all__.append(FlashNeoXSharded) __all__.append(FlashNeoXSharded)
__all__.append(FlashRW)
__all__.append(FlashRWSharded) __all__.append(FlashRWSharded)
__all__.append(FlashSantacoder)
__all__.append(FlashSantacoderSharded) __all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama) __all__.append(FlashLlama)
__all__.append(FlashLlamaSharded)
FLASH_ATT_ERROR_MESSAGE = ( FLASH_ATT_ERROR_MESSAGE = (
"{} requires Flash Attention CUDA kernels to be installed.\n" "{} requires Flash Attention CUDA kernels to be installed.\n"
...@@ -102,36 +96,24 @@ def get_model( ...@@ -102,36 +96,24 @@ def get_model(
trust_remote_code: bool, trust_remote_code: bool,
) -> Model: ) -> Model:
if "facebook/galactica" in model_id: if "facebook/galactica" in model_id:
if sharded:
return GalacticaSharded( return GalacticaSharded(
model_id, model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return Galactica(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
) )
if model_id.startswith("bigcode/"): if model_id.startswith("bigcode/"):
if sharded: if FLASH_ATTENTION:
if not FLASH_ATTENTION:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
return FlashSantacoderSharded( return FlashSantacoderSharded(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else: else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder return SantaCoder(
return santacoder_cls(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
...@@ -144,20 +126,19 @@ def get_model( ...@@ -144,20 +126,19 @@ def get_model(
model_type = config_dict["model_type"] model_type = config_dict["model_type"]
if model_type == "gpt_bigcode": if model_type == "gpt_bigcode":
if sharded: if FLASH_ATTENTION:
if not FLASH_ATTENTION:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
return FlashSantacoderSharded( return FlashSantacoderSharded(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else: else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder return SantaCoder(
return santacoder_cls(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
...@@ -165,33 +146,45 @@ def get_model( ...@@ -165,33 +146,45 @@ def get_model(
) )
if model_type == "bloom": if model_type == "bloom":
if sharded:
return BLOOMSharded( return BLOOMSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
elif model_type == "gpt_neox":
if FLASH_ATTENTION:
return FlashNeoXSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
elif sharded:
return GPTNeoxSharded(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
else: else:
return BLOOM( return CausalLM(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type == "gpt_neox": elif model_type == "llama":
if sharded: if FLASH_ATTENTION:
neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded return FlashLlama(
return neox_cls(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
else: else:
neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM return CausalLM(
return neox_cls(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
...@@ -217,7 +210,7 @@ def get_model( ...@@ -217,7 +210,7 @@ def get_model(
) )
else: else:
if FLASH_ATTENTION and not config_dict.get("alibi", False): if FLASH_ATTENTION and not config_dict.get("alibi", False):
return FlashRW( return FlashRWSharded(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
...@@ -231,42 +224,12 @@ def get_model( ...@@ -231,42 +224,12 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type == "llama": elif model_type == "opt":
if sharded:
if FLASH_ATTENTION:
return FlashLlamaSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
else:
llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
return llama_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "opt":
if sharded:
return OPTSharded( return OPTSharded(
model_id, model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return OPT(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
) )
if model_type == "t5": elif model_type == "t5":
if sharded: if sharded:
return T5Sharded( return T5Sharded(
model_id, model_id,
......
import torch import torch
import torch.distributed import torch.distributed
from typing import List, Optional, Type from typing import Optional, Type
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
AutoModelForCausalLM,
AutoConfig, AutoConfig,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from transformers.models.bloom.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models import CausalLM from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights,
) )
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
class BloomCausalLMBatch(CausalLMBatch): class BloomCausalLMBatch(CausalLMBatch):
@classmethod @classmethod
...@@ -42,34 +31,12 @@ class BloomCausalLMBatch(CausalLMBatch): ...@@ -42,34 +31,12 @@ class BloomCausalLMBatch(CausalLMBatch):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "CausalLMBatch": ) -> "CausalLMBatch":
batch = super(BloomCausalLMBatch, cls).from_pb( batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
pb=pb, tokenizer=tokenizer, dtype=dtype, device=device
)
batch.keys_head_dim_last = False batch.keys_head_dim_last = False
return batch return batch
class BLOOM(CausalLM): class BLOOMSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
super(BLOOM, self).__init__(
model_id=model_id,
revision=revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch
class BLOOMSharded(BLOOM):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
...@@ -101,25 +68,16 @@ class BLOOMSharded(BLOOM): ...@@ -101,25 +68,16 @@ class BLOOMSharded(BLOOM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config.pad_token_id = 3 config.pad_token_id = 3
config.quantize = quantize
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
with init_empty_weights(): filenames, device=device, dtype=dtype, process_group=self.process_group
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
) )
torch.distributed.barrier(group=self.process_group) model = BloomForCausalLM(config, weights)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model, model=model,
...@@ -131,132 +89,9 @@ class BLOOMSharded(BLOOM): ...@@ -131,132 +89,9 @@ class BLOOMSharded(BLOOM):
world_size=world_size, world_size=world_size,
) )
@staticmethod @property
def load_weights( def batch_type(self) -> Type[CausalLMBatch]:
model, return BloomCausalLMBatch
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
if name.startswith("transformer.") or name.startswith("lm_head."):
full_name = name
else:
full_name = f"transformer.{name}"
module_name, param_name = full_name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[full_name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif (
isinstance(module, TensorParallelEmbedding)
or name == "lm_head.weight"
):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
...@@ -269,9 +104,5 @@ class BLOOMSharded(BLOOM): ...@@ -269,9 +104,5 @@ class BLOOMSharded(BLOOM):
use_cache=True, use_cache=True,
) )
# Logits are sharded, so we need to gather them logits = outputs.logits
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values return logits, outputs.past_key_values
...@@ -30,21 +30,23 @@ import flash_attn_cuda ...@@ -30,21 +30,23 @@ import flash_attn_cuda
import dropout_layer_norm import dropout_layer_norm
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead,
) )
class LlamaRMSNorm(nn.Module): class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, prefix, weights, eps=1e-6):
""" """
LlamaRMSNorm is equivalent to T5LayerNorm LlamaRMSNorm is equivalent to T5LayerNorm
""" """
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps self.variance_epsilon = eps
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
...@@ -91,34 +93,34 @@ class LlamaRMSNorm(nn.Module): ...@@ -91,34 +93,34 @@ class LlamaRMSNorm(nn.Module):
class FlashLlamaAttention(torch.nn.Module): class FlashLlamaAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
num_heads, prefix: str,
hidden_size, config,
process_group=None, weights,
): ):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = config.num_attention_heads
self.hidden_size = hidden_size self.hidden_size = config.hidden_size
self.head_size = hidden_size // num_heads self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size ** (-0.5)
if process_group is None: self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False) self.query_key_value = TensorParallelColumnLinear.load_multi(
self.o_proj = FastLinear(hidden_size, hidden_size, bias=False) config,
else: prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
self.num_heads = self.num_heads // process_group.size() dim=0,
self.query_key_value = TensorParallelColumnLinear( weights=weights,
hidden_size,
3 * hidden_size,
bias=False, bias=False,
process_group=process_group,
) )
self.o_proj = TensorParallelRowLinear( self.o_proj = TensorParallelRowLinear.load(
hidden_size, config,
hidden_size, prefix=f"{prefix}.o_proj",
weights=weights,
bias=False, bias=False,
process_group=process_group,
) )
def forward( def forward(
...@@ -195,8 +197,9 @@ class FlashLlamaAttention(torch.nn.Module): ...@@ -195,8 +197,9 @@ class FlashLlamaAttention(torch.nn.Module):
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
act = config.hidden_act
self.act = ( self.act = (
ACT2FN[act] ACT2FN[act]
if "gelu" not in act if "gelu" not in act
...@@ -207,32 +210,23 @@ class LlamaMLP(nn.Module): ...@@ -207,32 +210,23 @@ class LlamaMLP(nn.Module):
else "none", else "none",
) )
) )
if process_group is None:
# Fuse gate and up proj
self.gate_up_proj = FastLinear(
hidden_size, 2 * intermediate_size, bias=False
)
self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False)
self.intermediate_size = intermediate_size
else:
# Fuse gate and up proj # Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear( self.gate_up_proj = TensorParallelColumnLinear.load_multi(
hidden_size, config,
2 * intermediate_size, prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False, bias=False,
process_group=process_group,
) )
self.down_proj = TensorParallelRowLinear( self.down_proj = TensorParallelRowLinear.load(
intermediate_size, config,
hidden_size, prefix=f"{prefix}.down_proj",
weights=weights,
bias=False, bias=False,
process_group=process_group,
reduce=True,
) )
self.intermediate_size = self.down_proj.in_features self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
self.process_group = process_group )
def forward(self, hidden_states): def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = self.gate_up_proj(hidden_states)
...@@ -241,22 +235,22 @@ class LlamaMLP(nn.Module): ...@@ -241,22 +235,22 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module): class FlashLlamaLayer(nn.Module):
def __init__( def __init__(self, layer_id, config, weights):
self,
num_heads,
act,
hidden_size,
intermediate_size,
rms_norm_eps,
process_group=None,
):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = FlashLlamaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group) self.input_layernorm = LlamaRMSNorm(
self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group) prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(
self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward( def forward(
self, self,
...@@ -295,54 +289,35 @@ class FlashLlamaLayer(nn.Module): ...@@ -295,54 +289,35 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module): class FlashLlamaModel(torch.nn.Module):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super(FlashLlamaModel, self).__init__() super().__init__()
self.config = config self.config = config
self.tp_embeddings = False process_group = weights.process_group
if process_group is not None:
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group prefix="model.embed_tokens", weights=weights
) )
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
FlashLlamaLayer( FlashLlamaLayer(
config.num_attention_heads, layer_id,
config.hidden_act, config,
config.hidden_size, weights,
config.intermediate_size,
config.rms_norm_eps,
process_group,
) )
for _ in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = LlamaRMSNorm(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
self.gradient_checkpointing = False self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads self.num_heads = self.layers[0].self_attn.num_heads
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.embed_tokens, TensorParallelEmbedding):
self.embed_tokens.add_null_idx()
for layer in self.layers:
layer: FlashLlamaLayer
layer.self_attn.query_key_value.prepare_weights(quantize)
layer.self_attn.o_proj.prepare_weights(quantize)
layer.mlp.gate_up_proj.prepare_weights(quantize)
layer.mlp.down_proj.prepare_weights(quantize)
def forward( def forward(
self, self,
input_ids, input_ids,
...@@ -410,29 +385,15 @@ class FlashLlamaModel(torch.nn.Module): ...@@ -410,29 +385,15 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super().__init__() super().__init__()
self.process_group = process_group self.model = FlashLlamaModel(config, weights)
if self.process_group is not None: self.lm_head = TensorParallelHead.load(
self.world_size = self.process_group.size() config,
else: prefix="lm_head",
self.world_size = 1 weights=weights,
self.model = FlashLlamaModel(config, process_group)
if self.model.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
) )
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.model.post_load_weights(quantize)
self.lm_head.prepare_weights()
def forward( def forward(
self, self,
...@@ -457,12 +418,4 @@ class FlashLlamaForCausalLM(torch.nn.Module): ...@@ -457,12 +418,4 @@ class FlashLlamaForCausalLM(torch.nn.Module):
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
if self.model.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present return logits, present
...@@ -31,61 +31,81 @@ from typing import Optional ...@@ -31,61 +31,81 @@ from typing import Optional
import flash_attn_cuda import flash_attn_cuda
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelHead,
FastLayerNorm, FastLayerNorm,
PositionRotaryEmbedding, PositionRotaryEmbedding,
get_linear,
) )
class FlashNeoxAttention(torch.nn.Module): def load_row(config, prefix: str, weights, bias: bool):
def __init__( weight = weights.get_sharded(f"{prefix}.weight", dim=1)
self, if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
if config.use_parallel_residual:
return linear
else:
return TensorParallelRowLinear(linear, process_group=weights.process_group)
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
weight = (
weight.view(
num_heads, num_heads,
3,
head_size,
hidden_size, hidden_size,
rotary_pct, )
rotary_emb_base, .permute(1, 0, 2, 3)
process_group=None, .reshape(-1, hidden_size)
reduce=True, )
): bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
linear = get_linear(weight, bias, config.quantize)
if config.use_parallel_residual:
return linear
else:
return TensorParallelColumnLinear(linear)
class FlashNeoxAttention(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__() super().__init__()
num_heads = config.num_attention_heads
hidden_size = config.hidden_size
self.num_heads = num_heads self.num_heads = num_heads
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads self.head_size = hidden_size // num_heads
self.num_heads = self.num_heads // weights.process_group.size()
rotary_ndims = int(self.head_size * rotary_pct) self.rotary_emb = PositionRotaryEmbedding.load(
self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base) prefix=f"{prefix}.rotary_emb", weights=weights
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size)
self.dense = FastLinear(hidden_size, hidden_size)
else:
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
3 * hidden_size,
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size, hidden_size, process_group=process_group, reduce=reduce
) )
def shuffle_qkv_dims(self): self.softmax_scale = self.head_size ** (-0.5)
"""Swap dims to avoid an additional permute"""
self.query_key_value.weight = torch.nn.Parameter( self.query_key_value = load_qkv(
self.query_key_value.weight.view( config,
self.num_heads, 3, self.head_size, self.hidden_size prefix=f"{prefix}.query_key_value",
) weights=weights,
.permute(1, 0, 2, 3) num_heads=self.num_heads,
.reshape(-1, self.hidden_size) head_size=self.head_size,
hidden_size=self.hidden_size,
) )
self.query_key_value.bias = torch.nn.Parameter( self.dense = load_row(
self.query_key_value.bias.view(self.num_heads, 3, self.head_size) config, prefix=f"{prefix}.dense", weights=weights, bias=True
.permute(1, 0, 2)
.reshape(-1)
) )
def forward( def forward(
...@@ -162,10 +182,9 @@ class FlashNeoxAttention(torch.nn.Module): ...@@ -162,10 +182,9 @@ class FlashNeoxAttention(torch.nn.Module):
class FlashMLP(nn.Module): class FlashMLP(nn.Module):
def __init__( def __init__(self, config, prefix, weights):
self, act, hidden_size, intermediate_size, process_group=None, reduce=True
):
super().__init__() super().__init__()
act = config.hidden_act
self.act = ( self.act = (
ACT2FN[act] ACT2FN[act]
if "gelu" not in act if "gelu" not in act
...@@ -177,22 +196,12 @@ class FlashMLP(nn.Module): ...@@ -177,22 +196,12 @@ class FlashMLP(nn.Module):
) )
) )
if process_group is None: self.dense_h_to_4h = TensorParallelColumnLinear.load(
self.dense_h_to_4h = FastLinear(hidden_size, intermediate_size) config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
self.dense_4h_to_h = FastLinear(intermediate_size, hidden_size)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
intermediate_size,
process_group=process_group,
) )
self.dense_4h_to_h = TensorParallelRowLinear( self.dense_4h_to_h = load_row(
intermediate_size, config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
hidden_size,
process_group=process_group,
reduce=reduce,
) )
self.process_group = process_group
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states) hidden_states = self.dense_h_to_4h(hidden_states)
...@@ -202,38 +211,28 @@ class FlashMLP(nn.Module): ...@@ -202,38 +211,28 @@ class FlashMLP(nn.Module):
class FlashNeoXLayer(nn.Module): class FlashNeoXLayer(nn.Module):
def __init__( def __init__(self, layer_id, config, weights):
self,
num_heads,
act,
hidden_size,
intermediate_size,
rotary_pct,
rotary_emb_base,
layer_norm_eps,
use_parallel_residual,
process_group=None,
):
super().__init__() super().__init__()
self.use_parallel_residual = use_parallel_residual
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) layer_norm_eps = config.layer_norm_eps
self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.attention = FlashNeoxAttention( prefix = f"gpt_neox.layers.{layer_id}"
num_heads,
hidden_size, self.use_parallel_residual = config.use_parallel_residual
rotary_pct, self.input_layernorm = FastLayerNorm.load(
rotary_emb_base, prefix=f"{prefix}.input_layernorm", weights=weights, eps=layer_norm_eps
process_group,
reduce=not use_parallel_residual,
) )
self.mlp = FlashMLP( self.post_attention_layernorm = FastLayerNorm.load(
act, prefix=f"{prefix}.post_attention_layernorm",
hidden_size, weights=weights,
intermediate_size, eps=layer_norm_eps,
process_group,
reduce=not use_parallel_residual,
) )
self.process_group = process_group self.attention = FlashNeoxAttention(
config, prefix=f"{prefix}.attention", weights=weights
)
self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
self.process_group = weights.process_group
def forward( def forward(
self, self,
...@@ -266,8 +265,6 @@ class FlashNeoXLayer(nn.Module): ...@@ -266,8 +265,6 @@ class FlashNeoXLayer(nn.Module):
mlp_output = self.mlp(ln2_hidden_states) mlp_output = self.mlp(ln2_hidden_states)
intermediate = mlp_output + attn_output intermediate = mlp_output + attn_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group) torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate + hidden_states, None return intermediate + hidden_states, None
...@@ -302,42 +299,24 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel): ...@@ -302,42 +299,24 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.embed_in = TensorParallelEmbedding( self.embed_in = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group prefix="gpt_neox.embed_in", weights=weights
) )
else:
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
FlashNeoXLayer( FlashNeoXLayer(layer_id, config, weights)
config.num_attention_heads, for layer_id in range(config.num_hidden_layers)
config.hidden_act,
config.hidden_size,
config.intermediate_size,
config.rotary_pct,
config.rotary_emb_base,
config.layer_norm_eps,
config.use_parallel_residual,
process_group,
)
for _ in range(config.num_hidden_layers)
] ]
) )
self.final_layer_norm = FastLayerNorm( self.final_layer_norm = FastLayerNorm.load(
config.hidden_size, eps=config.layer_norm_eps prefix="gpt_neox.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
...@@ -345,29 +324,6 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): ...@@ -345,29 +324,6 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self.head_size = self.layers[0].attention.head_size self.head_size = self.layers[0].attention.head_size
self.num_heads = self.layers[0].attention.num_heads self.num_heads = self.layers[0].attention.num_heads
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.embed_in, TensorParallelEmbedding):
self.embed_in.add_null_idx()
for layer in self.layers:
layer: FlashNeoXLayer
layer.attention.shuffle_qkv_dims()
layer.attention.query_key_value.prepare_weights(quantize)
layer.attention.dense.prepare_weights(quantize)
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashGPTNeoXModel, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward( def forward(
self, self,
input_ids, input_ids,
...@@ -435,43 +391,14 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): ...@@ -435,43 +391,14 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super().__init__(config) super().__init__(config)
self.gpt_neox = FlashGPTNeoXModel(config, weights)
self.process_group = process_group self.embed_out = TensorParallelHead.load(
if self.process_group is not None: config, prefix="embed_out", weights=weights
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.gpt_neox = FlashGPTNeoXModel(config, process_group)
if self.gpt_neox.tp_embeddings:
self.embed_out = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.embed_out = FastLinear(
config.hidden_size, config.vocab_size, bias=False
) )
def post_load_weights(self, quantize: Optional[str] = None):
self.gpt_neox.post_load_weights(quantize)
self.embed_out.prepare_weights()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward( def forward(
self, self,
input_ids, input_ids,
...@@ -495,12 +422,4 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): ...@@ -495,12 +422,4 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.embed_out(hidden_states) logits = self.embed_out(hidden_states)
if self.gpt_neox.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present return logits, present
import os
import torch import torch
import torch.distributed import torch.distributed
...@@ -12,15 +10,31 @@ from typing import Optional ...@@ -12,15 +10,31 @@ from typing import Optional
import flash_attn_cuda import flash_attn_cuda
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelHead,
FastLayerNorm, FastLayerNorm,
PositionRotaryEmbedding, PositionRotaryEmbedding,
get_linear,
) )
def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
if config.parallel_attn:
return linear
else:
return TensorParallelRowLinear(linear, process_group=weights.process_group)
class RWConfig(PretrainedConfig): class RWConfig(PretrainedConfig):
attribute_map = { attribute_map = {
"num_hidden_layers": "n_layer", "num_hidden_layers": "n_layer",
...@@ -85,44 +99,31 @@ class RWConfig(PretrainedConfig): ...@@ -85,44 +99,31 @@ class RWConfig(PretrainedConfig):
class FlashRWAttention(torch.nn.Module): class FlashRWAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
num_heads, config,
num_heads_kv, prefix,
hidden_size, weights,
bias,
process_group=None,
reduce=True,
): ):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = config.n_head
self.num_heads_kv = num_heads_kv self.num_heads_kv = config.n_head_kv
self.hidden_size = hidden_size self.hidden_size = config.hidden_size
self.head_size = hidden_size // num_heads self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size ** (-0.5)
self.num_heads = self.num_heads // weights.process_group.size()
if process_group is None: self.query_key_value = TensorParallelColumnLinear.load(
self.query_key_value = FastLinear( config,
hidden_size, prefix=f"{prefix}.query_key_value",
self.head_size * (self.num_heads + 2 * self.num_heads_kv), weights=weights,
bias=bias, bias=config.bias,
) )
self.dense = FastLinear(hidden_size, hidden_size, bias=bias) self.dense = load_row(
else: config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
) )
self.num_heads = self.num_heads // process_group.size()
def forward( def forward(
self, self,
...@@ -212,58 +213,49 @@ class FlashRWAttention(torch.nn.Module): ...@@ -212,58 +213,49 @@ class FlashRWAttention(torch.nn.Module):
class FlashRWLargeAttention(torch.nn.Module): class FlashRWLargeAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
num_heads, config,
num_heads_kv, prefix,
hidden_size, weights,
bias,
process_group=None,
reduce=True,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size
num_heads = config.n_head
num_heads_kv = config.n_head_kv
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size ** (-0.5)
self.num_groups = num_heads // (num_heads_kv * 2) self.num_groups = num_heads // (num_heads_kv * 2)
self.num_heads = num_heads // self.num_groups self.num_heads = num_heads // self.num_groups
self.num_heads_kv = num_heads_kv // self.num_groups self.num_heads_kv = num_heads_kv // self.num_groups
process_group = weights.process_group
if process_group is None:
self.query_key_value = FastLinear(
hidden_size,
self.num_groups
* self.head_size
* (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
)
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else:
if process_group.size() > self.num_groups: if process_group.size() > self.num_groups:
raise NotImplementedError( raise NotImplementedError(
f"Tensor Parallelism is not implemented for world_size > n groups" f"Tensor Parallelism is not implemented for world_size > n groups"
) )
if self.num_groups % process_group.size() != 0:
raise NotImplementedError(
f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
)
self.num_groups = self.num_groups // process_group.size()
self.query_key_value = TensorParallelColumnLinear( self.query_key_value = TensorParallelColumnLinear.load(
hidden_size, config,
self.num_groups prefix=f"{prefix}.query_key_value",
* self.head_size weights=weights,
* (self.num_heads + 2 * self.num_heads_kv), bias=config.bias,
bias=bias,
process_group=process_group,
) )
self.dense = TensorParallelRowLinear( self.dense = load_row(
hidden_size, config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
) )
self.num_groups = self.num_groups // process_group.size()
def forward( def forward(
self, self,
hidden_states, hidden_states,
...@@ -359,28 +351,16 @@ class FlashRWLargeAttention(torch.nn.Module): ...@@ -359,28 +351,16 @@ class FlashRWLargeAttention(torch.nn.Module):
class FlashMLP(nn.Module): class FlashMLP(nn.Module):
def __init__(self, hidden_size, bias, process_group=None, reduce=True): def __init__(self, config, prefix, weights):
super().__init__() super().__init__()
self.act = torch.nn.functional.gelu self.act = torch.nn.functional.gelu
if process_group is None: self.dense_h_to_4h = TensorParallelColumnLinear.load(
self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias) config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
4 * hidden_size,
bias=bias,
process_group=process_group,
) )
self.dense_4h_to_h = TensorParallelRowLinear( self.dense_4h_to_h = load_row(
4 * hidden_size, config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
) )
self.process_group = process_group
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states) hidden_states = self.dense_h_to_4h(hidden_states)
...@@ -392,38 +372,44 @@ class FlashMLP(nn.Module): ...@@ -392,38 +372,44 @@ class FlashMLP(nn.Module):
class FlashRWLayer(nn.Module): class FlashRWLayer(nn.Module):
def __init__( def __init__(
self, self,
num_heads, layer_id,
num_heads_kv, config,
hidden_size, weights,
bias,
layer_norm_eps,
parallel_attn,
process_group=None,
): ):
super().__init__() super().__init__()
parallel_attn = config.parallel_attn
self.parallel_attn = parallel_attn self.parallel_attn = parallel_attn
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) prefix = f"transformer.h.{layer_id}"
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.self_attention = FlashRWAttention( self.self_attention = FlashRWAttention(
num_heads, config,
num_heads_kv, prefix=f"{prefix}.self_attention",
hidden_size, weights=weights,
bias,
process_group=process_group,
reduce=False,
) )
self.post_attention_layernorm = ( self.post_attention_layernorm = (
FastLayerNorm(hidden_size, eps=layer_norm_eps) FastLayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
if not parallel_attn if not parallel_attn
else None else None
) )
self.mlp = FlashMLP( self.mlp = FlashMLP(
hidden_size, bias, process_group=process_group, reduce=False config,
prefix=f"{prefix}.mlp",
weights=weights,
) )
self.process_group = process_group self.process_group = weights.process_group
def forward( def forward(
self, self,
...@@ -454,8 +440,6 @@ class FlashRWLayer(nn.Module): ...@@ -454,8 +440,6 @@ class FlashRWLayer(nn.Module):
mlp_output = self.mlp(ln_hidden_states) mlp_output = self.mlp(ln_hidden_states)
intermediate = mlp_output + attn_output intermediate = mlp_output + attn_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group) torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate, residual return intermediate, residual
...@@ -483,33 +467,30 @@ class FlashRWLayer(nn.Module): ...@@ -483,33 +467,30 @@ class FlashRWLayer(nn.Module):
class FlashRWLargeLayer(nn.Module): class FlashRWLargeLayer(nn.Module):
def __init__( def __init__(self, layer_id, config, weights):
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
layer_norm_eps,
process_group=None,
):
super().__init__() super().__init__()
self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps) prefix = f"transformer.h.{layer_id}"
self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps) self.ln_attn = FastLayerNorm.load(
prefix=f"{prefix}.ln_attn",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.ln_mlp = FastLayerNorm.load(
prefix=f"{prefix}.ln_mlp",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.self_attention = FlashRWLargeAttention( self.self_attention = FlashRWLargeAttention(
num_heads, config,
num_heads_kv, prefix=f"{prefix}.self_attention",
hidden_size, weights=weights,
bias,
process_group=process_group,
reduce=False,
) )
assert config.parallel_attn, "This version doesn't support non parallel_attn"
self.mlp = FlashMLP( self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
hidden_size, bias, process_group=process_group, reduce=False
)
self.process_group = process_group self.process_group = weights.process_group
def forward( def forward(
self, self,
...@@ -543,8 +524,6 @@ class FlashRWLargeLayer(nn.Module): ...@@ -543,8 +524,6 @@ class FlashRWLargeLayer(nn.Module):
intermediate = attn_output + mlp_output intermediate = attn_output + mlp_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group) torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate, residual return intermediate, residual
...@@ -555,37 +534,18 @@ class FlashRWPreTrainedModel(PreTrainedModel): ...@@ -555,37 +534,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
class FlashRWModel(FlashRWPreTrainedModel): class FlashRWModel(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.word_embeddings = TensorParallelEmbedding( self.word_embeddings = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group prefix="transformer.word_embeddings", weights=weights
) )
else:
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
if config.model_type == "RefinedWebModel": if config.model_type == "RefinedWebModel":
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [
FlashRWLayer( FlashRWLayer(layer_id, config, weights)
config.n_head, for layer_id in range(config.num_hidden_layers)
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
config.parallel_attn,
process_group,
)
for _ in range(config.num_hidden_layers)
] ]
) )
self.cache_size = ( self.cache_size = (
...@@ -596,15 +556,8 @@ class FlashRWModel(FlashRWPreTrainedModel): ...@@ -596,15 +556,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
elif config.model_type == "RefinedWeb": elif config.model_type == "RefinedWeb":
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [
FlashRWLargeLayer( FlashRWLargeLayer(layer_id, config, weights)
config.n_head, for layer_id in range(config.num_hidden_layers)
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
process_group,
)
for _ in range(config.num_hidden_layers)
] ]
) )
self.cache_size = ( self.cache_size = (
...@@ -617,31 +570,13 @@ class FlashRWModel(FlashRWPreTrainedModel): ...@@ -617,31 +570,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
f"model_type {config.model_type} is not supported." f"model_type {config.model_type} is not supported."
) )
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f",
self.head_size = self.h[0].self_attention.head_size weights=weights,
eps=config.layer_norm_epsilon,
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.word_embeddings, TensorParallelEmbedding):
self.word_embeddings.add_null_idx()
for layer in self.h:
layer: FlashRWLayer
layer.self_attention.query_key_value.prepare_weights(quantize)
layer.self_attention.dense.prepare_weights(quantize)
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWModel, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
) )
model.post_load_weights("bitsandbytes" if load_in_8bit else None) self.head_size = self.h[0].self_attention.head_size
return model
def forward( def forward(
self, self,
...@@ -708,40 +643,14 @@ class FlashRWModel(FlashRWPreTrainedModel): ...@@ -708,40 +643,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
class FlashRWForCausalLM(FlashRWPreTrainedModel): class FlashRWForCausalLM(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super().__init__(config) super().__init__(config)
self.process_group = process_group self.transformer = FlashRWModel(config, weights)
if self.process_group is not None:
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.transformer = FlashRWModel(config, process_group)
if self.transformer.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None): self.lm_head = TensorParallelHead.load(
self.transformer.post_load_weights(quantize) config, prefix="lm_head", weights=weights
self.lm_head.prepare_weights()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
) )
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward( def forward(
self, self,
...@@ -766,12 +675,4 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): ...@@ -766,12 +675,4 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present return logits, present
import torch import torch
import torch.distributed import torch.distributed
from accelerate import init_empty_weights
from opentelemetry import trace from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, AutoConfig from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List from typing import Optional
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_neox_modeling import ( from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM, FlashGPTNeoXForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
) )
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights,
) )
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
class FlashNeoX(FlashCausalLM): class FlashNeoXSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
super(FlashNeoX, self).__init__(
FlashGPTNeoXForCausalLM,
model_id,
revision,
quantize,
trust_remote_code=trust_remote_code,
)
class FlashNeoXSharded(FlashNeoX):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
...@@ -65,23 +44,16 @@ class FlashNeoXSharded(FlashNeoX): ...@@ -65,23 +44,16 @@ class FlashNeoXSharded(FlashNeoX):
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
with init_empty_weights(): model = FlashGPTNeoXForCausalLM(config, weights)
model = FlashGPTNeoXForCausalLM(config, self.process_group)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model.to(device), model=model.to(device),
...@@ -92,79 +64,3 @@ class FlashNeoXSharded(FlashNeoX): ...@@ -92,79 +64,3 @@ class FlashNeoXSharded(FlashNeoX):
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
) )
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
model.post_load_weights(quantize)
[Two collapsed file diffs not shown]
from text_generation_server.utils.convert import convert_file, convert_files from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.weights import Weights
from text_generation_server.utils.hub import ( from text_generation_server.utils.hub import (
weight_files, weight_files,
weight_hub_files, weight_hub_files,
...@@ -35,4 +36,5 @@ __all__ = [ ...@@ -35,4 +36,5 @@ __all__ = [
"StoppingCriteria", "StoppingCriteria",
"StopSequenceCriteria", "StopSequenceCriteria",
"FinishReason", "FinishReason",
"Weights",
] ]