Unverified commit abd58ff8 authored by Nicolas Patry, committed by GitHub

feat(server): Rework model loading (#344)

# What does this PR do?

Reworked the loading logic. The idea is to use cleaner loading code:

- Remove the need for `no_init_weights`
- Remove all the weird `bnb_linear`, `load_weights` and
`post_load_weights` code.

New code layout:

- A new class `Weights` is in charge of loading the weights from
multiple files into the appropriate tensors (potentially sharded); see the
usage sketch after this list.
- TP layers are now "shells": they contain the code that knows which kind of
sharding we need plus the eventual `all_reduce`. They do not inherit from
Linear, but contain some kind of Linear instead.
- The contained linear can be either `FastLinear`, the bitsandbytes linear, or
a GPTQ linear later on.
- All modeling code is explicitly written for sharding; the process group is
just a no-op for non-sharded code (removes a lot of test cases).
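
Roughly, this is how the pieces are meant to fit together from modeling code.
This is an illustrative sketch only: the module paths, checkpoint file names,
weight prefixes and the dummy config below are assumptions and not part of this
diff; the `Weights`, `TensorParallel*` and `initialize_torch_distributed` APIs
are the ones introduced here.

```python
# Illustration only -- file names, prefixes, module paths and the dummy config
# are made up; the APIs are the ones added in this PR.
import torch

from text_generation_server.utils.dist import initialize_torch_distributed  # assumed path
from text_generation_server.utils.weights import Weights  # assumed path
from text_generation_server.utils.layers import (  # assumed path
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelHead,
    TensorParallelRowLinear,
)

# With WORLD_SIZE=1 this returns a FakeGroup, so the sharded modeling code
# below runs unchanged on a single device (the all_reduce is a no-op).
process_group, rank, world_size = initialize_torch_distributed()

weights = Weights(
    filenames=["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"],
    device=torch.device("cuda", rank) if torch.cuda.is_available() else torch.device("cpu"),
    dtype=torch.float16,
    process_group=process_group,
)

class DummyConfig:
    # get_linear() only looks at `quantize` (None, "bitsandbytes" or "gptq")
    quantize = None

config = DummyConfig()

# Each rank loads only its shard of every tensor; the "shell" layers wrap the
# inner linear and add the eventual all_reduce / all_gather.
embed = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights)
q_proj = TensorParallelColumnLinear.load(
    config, prefix="model.layers.0.self_attn.q_proj", weights=weights, bias=False
)
o_proj = TensorParallelRowLinear.load(
    config, prefix="model.layers.0.self_attn.o_proj", weights=weights, bias=False
)
lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights)
```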

![Screenshot from 2023-05-19 23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.taildb5d.ts.net>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
parent 19c41824
@@ -4,6 +4,37 @@ import torch
from datetime import timedelta
class FakeBarrier:
def wait(self):
pass
class FakeGroup:
def __init__(self, rank, size):
self._rank = rank
self._size = size
def allreduce(self, *args, **kwargs):
return FakeBarrier()
def allgather(self, inputs, local_tensor, **kwargs):
assert (
len(inputs[0]) == len(local_tensor) == 1
), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
for input_ in inputs:
input_[0].data = local_tensor[0].data
return FakeBarrier()
def barrier(self, *args, **kwargs):
return FakeBarrier()
def size(self):
return self._size
def rank(self):
return self._rank
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -23,13 +54,18 @@ def initialize_torch_distributed():
        backend = "gloo"
        options = None

-    # Call the init process.
-    torch.distributed.init_process_group(
-        backend=backend,
-        world_size=world_size,
-        rank=rank,
-        timeout=timedelta(seconds=60),
-        pg_options=options,
-    )
+    if world_size == 1:
+        return FakeGroup(rank, world_size), rank, world_size
+    else:
+        if os.getenv("DEBUG", None) == "1":
+            return FakeGroup(rank, world_size), rank, world_size
+
+        # Call the init process.
+        torch.distributed.init_process_group(
+            backend=backend,
+            world_size=world_size,
+            rank=rank,
+            timeout=timedelta(seconds=60),
+            pg_options=options,
+        )

-    return torch.distributed.group.WORLD, rank, world_size
+        return torch.distributed.group.WORLD, rank, world_size
import torch
import torch.distributed
from torch import nn
from torch.nn import functional as F
-from typing import Optional
+from typing import List

HAS_BITS_AND_BYTES = True
try:
-    from bitsandbytes.nn import Linear8bitLt
-except ImportError as e:
+    import bitsandbytes as bnb
+    from bitsandbytes.nn import Int8Params
+except ImportError:
    HAS_BITS_AND_BYTES = False
from accelerate import init_empty_weights
# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
weight = weights.get_tensor(f"{prefix}.weight")
bias = weights.get_tensor(f"{prefix}.bias")
with init_empty_weights():
ln = cls(weight.shape, eps=eps)
ln.weight = nn.Parameter(weight)
ln.bias = nn.Parameter(bias)
return ln
torch.nn.LayerNorm.load = load_layer_norm
-class FastLinear(nn.Linear):
+class FastLinear(nn.Module):
    def __init__(
        self,
-        in_features: int,
-        out_features: int,
-        bias: bool = True,
-        device=None,
-        dtype=None,
+        weight,
+        bias,
    ) -> None:
-        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
-        self.quantized = False
-        self.bnb_linear = None
-
-    def prepare_weights(self, quantize: Optional[str] = None):
-        if quantize == "bitsandbytes":
-            if not HAS_BITS_AND_BYTES:
-                raise ImportError(
-                    "bitsandbytes is not available on your machine either because it is not installed "
-                    "or you don't have a GPU.\n"
-                    "You can install it with `pip install bitsandbytes`."
-                )
-
-            self.quantized = True
-            self.bnb_linear = Linear8bitLt(
-                self.in_features,
-                self.out_features,
-                has_fp16_weights=False,
-                threshold=6.0,
-                bias=False,
-            )
-            # Copy data to bnb_linear
-            self.bnb_linear.weight.data = self.weight.data
-            if self.bias is not None:
-                self.bnb_linear.bias = nn.Parameter(self.bias)
-
-            # Delete reference to data
-            self.weight = None
-            self.bias = None
-        elif quantize == "gptq":
-            raise NotImplementedError("`gptq` is not implemented for now")
-        elif quantize is None:
-            self.weight = nn.Parameter(self.weight.T)
-        else:
-            raise ValueError(f"Unexpected quantize `{quantize}`")
+        super().__init__()
+        self.weight = nn.Parameter(weight)
+        if bias is not None:
+            self.bias = nn.Parameter(bias)
+        else:
+            self.bias = None
+
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
+        weight = weights.get_tensor(f"{prefix}.weight")
+        if bias:
+            bias = weights.get_tensor(f"{prefix}.bias")
+        else:
+            bias = None
+        return cls(weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.quantized:
-            return self.bnb_linear(input)
-        else:
-            if self.bias is not None:
-                return torch.addmm(self.bias, input, self.weight)
-            return torch.matmul(input, self.weight)
+        return F.linear(input, self.weight, self.bias)
-class TensorParallelColumnLinear(FastLinear):
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        bias=True,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_world_size = process_group.size()
-        assert out_features % self.tp_world_size == 0
-        out_features = out_features // self.tp_world_size
-
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            bias=bias,
-            device=device,
-            dtype=dtype,
-        )
+class Linear8bitLt(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+    ):
+        super().__init__()
+        assert (
+            not memory_efficient_backward
+        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
+        self.state = bnb.MatmulLtState()
+        self.index = index
+
+        # Necessary for stacked layers
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.weight = Int8Params(
+            weight.data,
+            has_fp16_weights=has_fp16_weights,
+            requires_grad=has_fp16_weights,
+        )
+        self.weight.cuda(weight.device)
+        self.bias = bias
+
+    def init_8bit_state(self):
+        self.state.CB = self.weight.CB
+        self.state.SCB = self.weight.SCB
+        self.weight.CB = None
+        self.weight.SCB = None
+
+    def forward(self, x: torch.Tensor):
+        self.state.is_training = self.training
+        if self.weight.CB is not None:
+            self.init_8bit_state()
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
+
+        if not self.state.has_fp16_weights:
+            if self.state.CB is not None and self.state.CxB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+        return out
-class TensorParallelRowLinear(FastLinear):
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        reduce=True,
-        bias=True,
-        device=None,
-        dtype=None,
-    ):
-        self.process_group = process_group
-        self.tp_world_size = process_group.size()
-        self.reduce = reduce
-        assert in_features % self.tp_world_size == 0
-        in_features = in_features // self.tp_world_size
-
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            bias=bias,
-            device=device,
-            dtype=dtype,
-        )
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        out = super(TensorParallelRowLinear, self).forward(input)
-        if self.reduce:
-            torch.distributed.all_reduce(out, group=self.process_group)
-
-        return out
+def get_linear(weight, bias, quantize):
+    if quantize is None:
+        linear = FastLinear(weight, bias)
+    elif quantize == "bitsandbytes":
+        linear = Linear8bitLt(
+            weight,
+            bias,
+            has_fp16_weights=False,
+            threshold=6.0,
+        )
+        if bias is not None:
+            linear.bias = nn.Parameter(bias)
+    elif quantize == "gptq":
+        raise NotImplementedError("Soon")
+    else:
+        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
+    return linear
+
+
+class SuperLayer(nn.Module):
+    def __init__(self, linear):
+        super().__init__()
+        self.linear = linear
+
+    def forward(self, x):
+        return self.linear.forward(x)
+
+
+class TensorParallelHead(SuperLayer):
+    def __init__(self, linear, process_group):
+        super().__init__(linear)
+        self.process_group = process_group
+
+    @staticmethod
+    def load(config, prefix: str, weights):
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        return TensorParallelHead(
+            get_linear(weight, bias=None, quantize=config.quantize),
+            process_group=weights.process_group,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = super().forward(input)
+        # Logits are sharded, so we need to gather them
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
+        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        world_output = torch.cat(world_output, dim=-1)
+        return world_output
+
+
+class TensorParallelColumnLinear(SuperLayer):
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        if bias:
+            bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+        else:
+            bias = None
+        return cls(get_linear(weight, bias, config.quantize))
+
+    @classmethod
+    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
+        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+        weight = torch.cat(w, dim=dim)
+
+        if bias:
+            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
+            bias = torch.cat(b, dim=0)
+        else:
+            bias = None
+        return cls(get_linear(weight, bias, config.quantize))
-class TensorParallelEmbedding(nn.Embedding):
-    def __init__(
-        self,
-        num_embeddings,
-        embedding_dim,
-        process_group: torch.distributed.ProcessGroup,
-        reduce=True,
-        padding_idx=None,
-        max_norm=None,
-        norm_type=2.0,
-        scale_grad_by_freq=False,
-        sparse=False,
-        _weight=None,
-        device=None,
-        dtype=None,
-    ):
-        self.reduce = reduce
-        self.process_group = process_group
-        self.tp_rank = process_group.rank()
-        self.tp_world_size = process_group.size()
-
-        self.original_num_embeddings = num_embeddings
-
-        assert num_embeddings % self.tp_world_size == 0
-        block_size = num_embeddings // self.tp_world_size
-        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
-        self.min_id = self.tp_rank * block_size
-        self.max_id = (self.tp_rank + 1) * block_size
-
-        # Additional entry that will map to zero
-        # Used for masking
-        self.null_idx = block_size
-
-        super().__init__(
-            block_size,
-            embedding_dim,
-            padding_idx=padding_idx,
-            max_norm=max_norm,
-            norm_type=norm_type,
-            scale_grad_by_freq=scale_grad_by_freq,
-            sparse=sparse,
-            _weight=_weight,
-            device=device,
-            dtype=dtype,
-        )
-
-    def add_null_idx(self):
-        """Additional 0 entry used for masking"""
-        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+class TensorParallelRowLinear(SuperLayer):
+    def __init__(self, linear, process_group):
+        super().__init__(linear)
+        self.process_group = process_group
+
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        if bias and weights.process_group.rank() == 0:
+            # Rank is only on the first rank process
+            bias = weights.get_tensor(f"{prefix}.bias")
+        else:
+            bias = None
+        return cls(
+            get_linear(weight, bias, config.quantize),
+            process_group=weights.process_group,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = super().forward(input)
+        torch.distributed.all_reduce(out, group=self.process_group)
+        return out
+
+
+class TensorParallelEmbedding(nn.Module):
+    def __init__(self, prefix: str, weights, reduce=True):
+        super().__init__()
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
+
+        process_group = weights.process_group
+        world_size = process_group.size()
+        rank = process_group.rank()
+
+        block_size = num_embeddings // world_size
+        self.min_id = rank * block_size
+        self.max_id = min(num_embeddings, (rank + 1) * block_size)
+        self.null_idx = block_size
+        self.process_group = weights.process_group
+        self.reduce = reduce
+
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
@@ -180,7 +244,7 @@ class TensorParallelEmbedding(nn.Embedding):
            self.null_idx,
            input - self.min_id,
        )
-        out = super().forward(input)
+        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
@@ -232,7 +296,34 @@ try:
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
-class PositionRotaryEmbedding(RotaryEmbedding):
+class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq):
super().__init__()
self.register_buffer("inv_freq", inv_freq)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
@classmethod
def static(cls, dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return cls(inv_freq)
@classmethod
def load(cls, prefix, weights):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
return cls(inv_freq)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
from pathlib import Path
from typing import List
from safetensors import safe_open
class Weights:
def __init__(self, filenames: List[Path], device, dtype, process_group):
routing = {}
for filename in filenames:
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
self.routing = routing
self.device = device
self.dtype = dtype
self.process_group = process_group
self._handles = {}
def _get_handle(self, filename):
if filename not in self._handles:
f = safe_open(filename, framework="pytorch")
self._handles[filename] = f
return self._handles[filename]
def get_filename(self, tensor_name: str) -> str:
filename = self.routing.get(tensor_name, None)
if filename is None:
raise RuntimeError(f"weight {tensor_name} does not exist")
return str(filename)
def _get_slice(self, tensor_name: str):
filename = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
return slice_
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str):
filename = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
def get_sharded(self, tensor_name: str, dim: int):
filename = self.get_filename(tensor_name)
world_size = self.process_group.size()
rank = self.process_group.rank()
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
assert (
size % world_size == 0
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
if dim == 0:
tensor = slice_[start:stop]
elif dim == 1:
tensor = slice_[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor