Commit 295831a4 authored by Olivier Dehaene

Init

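# --- bloom_inference.cache ---
# (module boundaries and names in the comments below are inferred from the package
# imports in this commit; the commit view itself does not show file names)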
import torch
from dataclasses import dataclass
from typing import Dict, Optional, List
from bloom_inference.pb import generate_pb2
from bloom_inference.utils import NextTokenChooser, StoppingCriteria
@dataclass
class CacheEntry:
batch_id: int
request_ids: List[int]
input_ids: Dict[str, torch.Tensor]
all_input_ids: List[torch.Tensor]
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
def __len__(self):
return len(self.request_ids)
def to_pb(self):
return generate_pb2.CacheEntry(
id=self.batch_id,
request_ids=self.request_ids,
sequence_length=max(len(entry) for entry in self.all_input_ids),
)
class Cache:
def __init__(self):
self.cache: Dict[str, CacheEntry] = {}
def pop(self, batch_id: str) -> Optional[CacheEntry]:
return self.cache.pop(batch_id, None)
def set(self, entry: CacheEntry):
if entry is not None:
self.cache[entry.batch_id] = entry
def delete(self, batch_id: str):
del self.cache[batch_id]
def clear(self):
self.cache.clear()
def __len__(self):
return len(self.cache.keys())
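# Illustrative sketch (not part of the original commit): how a caller is expected to
# drive the Cache between generation steps. `entry` stands for a CacheEntry returned
# by the model for requests that are not finished yet.
def _cache_usage_example(entry: CacheEntry):
    cache = Cache()
    cache.set(entry)                   # keep unfinished requests around
    same = cache.pop(entry.batch_id)   # retrieve and remove in one call
    assert same is entry
    cache.set(same)
    cache.delete(same.batch_id)        # raises KeyError for unknown ids
    cache.clear()                      # drop everything, e.g. on a ClearCache RPC


# --- CLI entrypoint (module name not shown in this view) ---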
import typer
from pathlib import Path
from torch.distributed.launcher import launch_agent, LaunchConfig
from typing import Optional
from bloom_inference.server import serve
def main(
model_name: str,
num_gpus: int = 1,
shard_directory: Optional[Path] = None,
):
if num_gpus == 1:
serve(model_name, False, shard_directory)
else:
config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=num_gpus,
rdzv_backend="c10d",
max_restarts=0,
)
launch_agent(config, serve, [model_name, True, shard_directory])
if __name__ == "__main__":
typer.run(main)
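# Example invocation (illustrative; the exact console-script or module name is not
# shown in this commit):
#   python -m <cli_module> bigscience/bloom-560m --num-gpus 2 --shard-directory /tmp/models
# With num_gpus > 1, torch.distributed's elastic launch_agent spawns one `serve`
# process per GPU, rendezvousing through the c10d backend with no restarts allowed.


# --- bloom_inference.model ---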
import torch
import torch.distributed
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Optional, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.modeling_utils import no_init_weights
from bloom_inference.cache import CacheEntry
from bloom_inference.pb import generate_pb2
from bloom_inference.shard_model import shard_model, match_suffix
from bloom_inference.utils import (
StoppingCriteria,
NextTokenChooser,
initialize_torch_distributed,
set_default_dtype,
)
torch.manual_seed(0)
@dataclass
class Batch:
batch_id: int
request_ids: List[int]
input_ids: Dict[str, torch.Tensor]
all_input_ids: List[torch.Tensor]
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
@classmethod
def from_batch_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "Batch":
request_ids = []
inputs = []
next_token_choosers = []
stopping_criterias = []
# Parse batch
for r in pb.requests:
request_ids.append(r.id)
inputs.append(r.inputs)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
top_k=r.parameters.top_k,
top_p=r.parameters.top_p,
do_sample=r.parameters.do_sample,
)
)
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device)
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
return cls(
pb.id,
request_ids,
input_ids,
all_input_ids,
next_token_choosers,
stopping_criterias,
)
@classmethod
def from_cache_entry(cls, cache_entry: CacheEntry) -> "Batch":
return cls(
cache_entry.batch_id,
cache_entry.request_ids,
cache_entry.input_ids,
cache_entry.all_input_ids,
cache_entry.next_token_choosers,
cache_entry.stopping_criterias,
)
@classmethod
def from_batch_cached_pb(cls, pb: generate_pb2.BatchCached, cache) -> "Batch":
if len(pb.batch_cached_ids) == 1:
cache_entry = cache.pop(pb.batch_cached_ids[0])
if cache_entry is None:
raise ValueError(f"Batch ID {pb.batch_cached_ids[0]} not found in cache")
return cls.from_cache_entry(cache_entry)
total_batch_size = pb.total_batch_size
max_sequence_length = pb.max_sequence_length
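# Merge all cached batches into one left-padded batch:
# - "input_ids" ends up as the last generated token of every request, shape
#   (total_batch_size, 1)
# - "attention_mask" is re-padded on the left up to max_sequence_length
# - past keys are stored as (batch * num_heads, head_dim, seq) and past values as
#   (batch * num_heads, seq, head_dim); they are viewed per-request, copied into
#   zero-initialized tensors of the maximum past length (max_sequence_length - 1),
#   and flattened back once the last cached batch has been copied in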
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
request_ids = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
start_index = 0
for i, batch_id in enumerate(pb.batch_cached_ids):
cache_entry = cache.pop(batch_id)
if cache_entry is None:
raise ValueError(f"Batch ID {batch_id} not found in cache")
request_ids.extend(cache_entry.request_ids)
all_input_ids.extend(cache_entry.all_input_ids)
next_token_choosers.extend(cache_entry.next_token_choosers)
stopping_criterias.extend(cache_entry.stopping_criterias)
batch_size = len(cache_entry.request_ids)
end_index = start_index + batch_size
sequence_length = max(len(entry) for entry in cache_entry.all_input_ids)
if input_ids["input_ids"] is None:
input_ids["input_ids"] = torch.empty(
(total_batch_size, 1),
dtype=cache_entry.input_ids["input_ids"].dtype,
device=cache_entry.input_ids["input_ids"].device,
)
input_ids["input_ids"][start_index:end_index] = cache_entry.input_ids[
"input_ids"
]
if input_ids["attention_mask"] is None:
input_ids["attention_mask"] = torch.zeros(
(total_batch_size, max_sequence_length),
dtype=cache_entry.input_ids["attention_mask"].dtype,
device=cache_entry.input_ids["attention_mask"].device,
)
input_ids["attention_mask"][
start_index:end_index, -sequence_length:
] = cache_entry.input_ids["attention_mask"][:, -sequence_length:]
for j, past in enumerate(cache_entry.input_ids["past_key_values"]):
# TODO: this could be done without the views by using indices
past_keys = past[0]
past_values = past[1]
_, head_dim, padded_sequence_length = past_keys.shape
past_keys = past_keys.view(
batch_size, -1, head_dim, padded_sequence_length
)
past_values = past_values.view(
batch_size, -1, padded_sequence_length, head_dim
)
num_heads = past_keys.shape[1]
if j == len(input_ids["past_key_values"]):
padded_past_keys = torch.zeros(
(
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
),
dtype=past_keys.dtype,
device=past_keys.device,
)
padded_past_values = torch.zeros(
(
total_batch_size,
num_heads,
max_sequence_length - 1,
head_dim,
),
dtype=past_values.dtype,
device=past_values.device,
)
input_ids["past_key_values"].append(
[padded_past_keys, padded_past_values]
)
input_ids["past_key_values"][j][0][
start_index:end_index, :, :, -(sequence_length - 1):
] = past_keys[:, :, :, -(sequence_length - 1):]
input_ids["past_key_values"][j][1][
start_index:end_index, :, -(sequence_length - 1):, :
] = past_values[:, :, -(sequence_length - 1):, :]
if (i + 1) == len(pb.batch_cached_ids):
input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
j
][0].view(total_batch_size * num_heads, head_dim, -1)
input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
j
][1].view(total_batch_size * num_heads, -1, head_dim)
start_index += batch_size
assert pb.request_ids == request_ids
return cls(
pb.id,
request_ids,
input_ids,
all_input_ids,
next_token_choosers,
stopping_criterias,
)
@dataclass
class FinishedGeneration:
request_id: str
output: str
def to_pb(self) -> generate_pb2.FinishedGeneration:
return generate_pb2.FinishedGeneration(id=self.request_id, output=self.output)
class BLOOM:
def __init__(self, model_name: str):
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.model = (
AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device)
)
self.num_heads = self.model.base_model.num_heads
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
# Model Forward
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
def generate_token(
self, batch: Batch
) -> Tuple[List[FinishedGeneration], Optional[CacheEntry]]:
with torch.no_grad():
outputs = self.forward(**batch.input_ids)
# List of indices to cache
cache_indices = []
cache_past_indices = []
# New input_ids for next forward; keep in cache
cache_next_input_ids = []
cache_all_input_ids = []
# Finished requests
finished_generations: List[FinishedGeneration] = []
# Zipped iterator
iterator = zip(
batch.request_ids,
outputs.logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request_id,
logits,
next_token_chooser,
stopping_criteria,
all_tokens,
) in enumerate(iterator):
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
# Append next token to all tokens
all_tokens = torch.cat([all_tokens, next_token])
# Evaluate stopping criteria
if stopping_criteria(all_tokens):
# Decode all tokens
output = self.tokenizer.decode(
all_tokens.squeeze(-1), skip_special_tokens=True
)
# Add to the list of finished generations with the original request id
finished_generations.append(FinishedGeneration(request_id, output))
# must be added to the cache
else:
cache_indices.append(i)
# past_key_values are stored flattened to (batch_size * num_heads, ...), so each
# kept request maps to num_heads consecutive rows in the past tensors
cache_past_indices.extend(range(i * self.num_heads, (i + 1) * self.num_heads))
cache_next_input_ids.append(next_token)
cache_all_input_ids.append(all_tokens)
# No cache is needed, we finished all generations in the batch
if not cache_indices:
return finished_generations, None
# If we finished at least one generation
cache_input_ids = {"input_ids": torch.cat(cache_next_input_ids, dim=0)}
if finished_generations:
# Apply indices to attention mask, past key values and other items that need to be cached
cache_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
cache_indices
]
cache_input_ids["past_key_values"] = [
(keys[cache_past_indices], values[cache_past_indices])
for keys, values in outputs["past_key_values"]
]
cache_request_ids = [batch.request_ids[i] for i in cache_indices]
cache_next_token_choosers = [
batch.next_token_choosers[i] for i in cache_indices
]
cache_stopping_criterias = [
batch.stopping_criterias[i] for i in cache_indices
]
else:
cache_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
cache_input_ids["past_key_values"] = outputs["past_key_values"]
cache_request_ids = batch.request_ids
cache_next_token_choosers = batch.next_token_choosers
cache_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
cache_input_ids["attention_mask"] = torch.cat(
[
cache_input_ids["attention_mask"],
torch.ones((cache_input_ids["attention_mask"].shape[0], 1)).to(
cache_input_ids["attention_mask"].device
),
],
dim=1,
)
cache_entry = CacheEntry(
batch.batch_id,
cache_request_ids,
cache_input_ids,
cache_all_input_ids,
cache_next_token_choosers,
cache_stopping_criterias,
)
return finished_generations, cache_entry
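# Illustrative sketch (not part of the original commit): a minimal local decode loop
# built on generate_token. The gRPC server below drives the same logic one step per
# Generate/GenerateWithCache call; here it simply loops until every request is done.
def _decode_loop_example(model: BLOOM, pb_batch: generate_pb2.Batch) -> List[FinishedGeneration]:
    batch = Batch.from_batch_pb(pb_batch, model.tokenizer, model.device)
    finished: List[FinishedGeneration] = []
    while True:
        done, cache_entry = model.generate_token(batch)
        finished.extend(done)
        if cache_entry is None:
            # every request hit its stopping criteria
            return finished
        # keep generating from the cached input_ids / attention_mask / past_key_values
        batch = Batch.from_cache_entry(cache_entry)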
class BLOOMSharded(BLOOM):
def __init__(self, model_name: str, shard_directory: Path):
super(BLOOM, self).__init__()  # intentionally skip BLOOM.__init__: tokenizer and model are rebuilt below with sharding
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():
self.device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16
else:
self.device = torch.device("cpu")
dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# shard state_dict
if self.master:
# TODO @thomasw21 do some caching
shard_state_dict_paths = shard_model(
model_name, shard_directory, tp_world_size=self.world_size, dtype=dtype
)
shard_state_dict_paths = [
str(path.absolute()) for path in shard_state_dict_paths
]
else:
shard_state_dict_paths = [None] * self.world_size
torch.distributed.broadcast_object_list(
shard_state_dict_paths, src=0, group=self.process_group
)
shard_state_dict_path = shard_state_dict_paths[self.rank]
config = AutoConfig.from_pretrained(
model_name, slow_but_exact=False, tp_parallel=True
)
config.pad_token_id = 3
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
with set_default_dtype(dtype):
with no_init_weights():
# we can probably set the device to `meta` here?
model = AutoModelForCausalLM.from_config(config).to(dtype)
torch.distributed.barrier(group=self.process_group)
# print_rank_0(f"Initialized model")
state_dict = torch.load(shard_state_dict_path)
# TODO @thomasw21: HACK in order to transpose all weight prior
for key in state_dict.keys():
do_transpose = False
if not match_suffix(key, "weight"):
continue
for potential_suffix in [
"self_attention.query_key_value.weight",
"self_attention.dense.weight",
"dense_h_to_4h.weight",
"dense_4h_to_h.weight",
]:
if match_suffix(key, potential_suffix):
do_transpose = True
if do_transpose:
state_dict[key] = state_dict[key].transpose(1, 0).contiguous()
model.load_state_dict(state_dict)
self.model = model.to(self.device).eval()
self.num_heads = config.n_head // self.process_group.size()
torch.distributed.barrier(group=self.process_group)
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
logits_shard = outputs.logits[:, -1, :].contiguous()
batch_size, vocab_shard_size = logits_shard.shape
vocab_size = self.world_size * vocab_shard_size
logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
outputs.logits = logits
return outputs
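# Note on BLOOMSharded.forward: each tensor-parallel rank only holds a shard of the
# LM head, so the last-position logits are all-gathered across the process group and
# concatenated along the vocabulary dimension before sampling.


# --- bloom_inference.pb.generate_pb2 (generated by protoc, do not edit) ---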
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: generate.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0egenerate.proto\x12\x0bgenerate.v1\"(\n\x18ServiceDiscoveryResponse\x12\x0c\n\x04urls\x18\x01 \x03(\t\"^\n\x16LogitsWarperParameters\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_k\x18\x02 \x01(\r\x12\r\n\x05top_p\x18\x03 \x01(\x02\x12\x11\n\tdo_sample\x18\x04 \x01(\x08\"v\n\x07Request\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0e\n\x06inputs\x18\x02 \x01(\t\x12\x37\n\nparameters\x18\x03 \x01(\x0b\x32#.generate.v1.LogitsWarperParameters\x12\x16\n\x0emax_new_tokens\x18\x04 \x01(\r\";\n\x05\x42\x61tch\x12\n\n\x02id\x18\x01 \x01(\x04\x12&\n\x08requests\x18\x02 \x03(\x0b\x32\x14.generate.v1.Request\"\x7f\n\x0b\x42\x61tchCached\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x13\n\x0brequest_ids\x18\x02 \x03(\x04\x12\x18\n\x10\x62\x61tch_cached_ids\x18\x03 \x03(\x04\x12\x18\n\x10total_batch_size\x18\x04 \x01(\r\x12\x1b\n\x13max_sequence_length\x18\x05 \x01(\r\"0\n\x12\x46inishedGeneration\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0e\n\x06output\x18\x02 \x01(\t\"F\n\nCacheEntry\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x13\n\x0brequest_ids\x18\x02 \x03(\x04\x12\x17\n\x0fsequence_length\x18\x03 \x01(\r\"\x80\x01\n\x08Response\x12\x31\n\x08\x66inished\x18\x01 \x03(\x0b\x32\x1f.generate.v1.FinishedGeneration\x12\x31\n\x0b\x63\x61\x63he_entry\x18\x02 \x01(\x0b\x32\x17.generate.v1.CacheEntryH\x00\x88\x01\x01\x42\x0e\n\x0c_cache_entry\"\x07\n\x05\x45mpty2\x94\x02\n\x0eTextGeneration\x12O\n\x10ServiceDiscovery\x12\x12.generate.v1.Empty\x1a%.generate.v1.ServiceDiscoveryResponse\"\x00\x12\x34\n\nClearCache\x12\x12.generate.v1.Empty\x1a\x12.generate.v1.Empty\x12\x35\n\x08Generate\x12\x12.generate.v1.Batch\x1a\x15.generate.v1.Response\x12\x44\n\x11GenerateWithCache\x12\x18.generate.v1.BatchCached\x1a\x15.generate.v1.Responseb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'generate_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_SERVICEDISCOVERYRESPONSE._serialized_start=31
_SERVICEDISCOVERYRESPONSE._serialized_end=71
_LOGITSWARPERPARAMETERS._serialized_start=73
_LOGITSWARPERPARAMETERS._serialized_end=167
_REQUEST._serialized_start=169
_REQUEST._serialized_end=287
_BATCH._serialized_start=289
_BATCH._serialized_end=348
_BATCHCACHED._serialized_start=350
_BATCHCACHED._serialized_end=477
_FINISHEDGENERATION._serialized_start=479
_FINISHEDGENERATION._serialized_end=527
_CACHEENTRY._serialized_start=529
_CACHEENTRY._serialized_end=599
_RESPONSE._serialized_start=602
_RESPONSE._serialized_end=730
_EMPTY._serialized_start=732
_EMPTY._serialized_end=739
_TEXTGENERATION._serialized_start=742
_TEXTGENERATION._serialized_end=1018
# @@protoc_insertion_point(module_scope)
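# Illustrative sketch (not part of the original commit): building the request
# messages defined in generate.proto, using the field names from the descriptor above.
#
#   from bloom_inference.pb import generate_pb2
#
#   request = generate_pb2.Request(
#       id=0,
#       inputs="Hello, my name is",
#       parameters=generate_pb2.LogitsWarperParameters(
#           temperature=1.0, top_k=0, top_p=1.0, do_sample=False
#       ),
#       max_new_tokens=20,
#   )
#   batch = generate_pb2.Batch(id=0, requests=[request])


# --- bloom_inference.pb.generate_pb2_grpc (generated by the gRPC Python plugin, do not edit) ---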
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
from . import generate_pb2 as generate__pb2
class TextGenerationStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.ServiceDiscovery = channel.unary_unary(
'/generate.v1.TextGeneration/ServiceDiscovery',
request_serializer=generate__pb2.Empty.SerializeToString,
response_deserializer=generate__pb2.ServiceDiscoveryResponse.FromString,
)
self.ClearCache = channel.unary_unary(
'/generate.v1.TextGeneration/ClearCache',
request_serializer=generate__pb2.Empty.SerializeToString,
response_deserializer=generate__pb2.Empty.FromString,
)
self.Generate = channel.unary_unary(
'/generate.v1.TextGeneration/Generate',
request_serializer=generate__pb2.Batch.SerializeToString,
response_deserializer=generate__pb2.Response.FromString,
)
self.GenerateWithCache = channel.unary_unary(
'/generate.v1.TextGeneration/GenerateWithCache',
request_serializer=generate__pb2.BatchCached.SerializeToString,
response_deserializer=generate__pb2.Response.FromString,
)
class TextGenerationServicer(object):
"""Missing associated documentation comment in .proto file."""
def ServiceDiscovery(self, request, context):
"""/ Service discovery
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ClearCache(self, request, context):
"""/ Empties batch cache
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Generate(self, request, context):
"""/ Generate tokens for a batch without cache
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GenerateWithCache(self, request, context):
"""/ Generate tokens for a batch with cache
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_TextGenerationServicer_to_server(servicer, server):
rpc_method_handlers = {
'ServiceDiscovery': grpc.unary_unary_rpc_method_handler(
servicer.ServiceDiscovery,
request_deserializer=generate__pb2.Empty.FromString,
response_serializer=generate__pb2.ServiceDiscoveryResponse.SerializeToString,
),
'ClearCache': grpc.unary_unary_rpc_method_handler(
servicer.ClearCache,
request_deserializer=generate__pb2.Empty.FromString,
response_serializer=generate__pb2.Empty.SerializeToString,
),
'Generate': grpc.unary_unary_rpc_method_handler(
servicer.Generate,
request_deserializer=generate__pb2.Batch.FromString,
response_serializer=generate__pb2.Response.SerializeToString,
),
'GenerateWithCache': grpc.unary_unary_rpc_method_handler(
servicer.GenerateWithCache,
request_deserializer=generate__pb2.BatchCached.FromString,
response_serializer=generate__pb2.Response.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'generate.v1.TextGeneration', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class TextGeneration(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def ServiceDiscovery(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/ServiceDiscovery',
generate__pb2.Empty.SerializeToString,
generate__pb2.ServiceDiscoveryResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def ClearCache(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/ClearCache',
generate__pb2.Empty.SerializeToString,
generate__pb2.Empty.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Generate(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/Generate',
generate__pb2.Batch.SerializeToString,
generate__pb2.Response.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def GenerateWithCache(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/GenerateWithCache',
generate__pb2.BatchCached.SerializeToString,
generate__pb2.Response.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
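# Illustrative sketch (not part of the original commit): calling the service through
# the generated stub over one of the unix sockets bound in bloom_inference.server.
#
#   import grpc
#   from bloom_inference.pb import generate_pb2, generate_pb2_grpc
#
#   channel = grpc.insecure_channel("unix:///tmp/bloom-inference-0")
#   stub = generate_pb2_grpc.TextGenerationStub(channel)
#   urls = stub.ServiceDiscovery(generate_pb2.Empty()).urls
#   response = stub.Generate(generate_pb2.Batch(id=0, requests=[...]))
#   for finished in response.finished:
#       print(finished.id, finished.output)


# --- weight preparation script (file name not shown in this view) ---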
import torch
from pathlib import Path
from tqdm import tqdm
MODEL_NAME = "bigscience/bloom"
def match_suffix(text, suffix):
return text[-len(suffix) :] == suffix
def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
save_paths = [
save_path / f"{MODEL_NAME}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
for tp_rank in range(tp_world_size)
]
if all(save_path.exists() for save_path in save_paths):
print("Weights are already prepared")
return save_paths
shards_state_dicts = [{} for _ in range(tp_world_size)]
for weight_path in tqdm(hub_path.glob("*.bin")):
state_dict = torch.load(weight_path, map_location="cpu")
keys = list(state_dict.keys())
for state_name in keys:
state = state_dict[state_name]
if any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.query_key_value.weight",
"self_attention.query_key_value.bias",
"mlp.dense_h_to_4h.weight",
"mlp.dense_h_to_4h.bias",
"word_embeddings.weight",
"lm_head.weight",
]
):
output_size = state.shape[0]
assert output_size % tp_world_size == 0
block_size = output_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=0)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[0] == block_size
if match_suffix(state_name, "lm_head.weight"):
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
else:
shards_state_dicts[tp_rank][
"transformer." + state_name
] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.weight",
"mlp.dense_4h_to_h.weight",
"lm_head.weight",
]
):
input_size = state.shape[1]
assert input_size % tp_world_size == 0
block_size = input_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=1)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[1] == block_size
if match_suffix(state_name, "lm_head.weight"):
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
else:
shards_state_dicts[tp_rank][
"transformer." + state_name
] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.bias",
"mlp.dense_4h_to_h.bias",
]
):
shards_state_dicts[0][
"transformer." + state_name
] = state.detach().clone()
for tp_rank in range(1, tp_world_size):
shards_state_dicts[tp_rank][
"transformer." + state_name
] = torch.zeros_like(state)
else:
# We duplicate parameters across tp ranks
for tp_rank in range(tp_world_size):
shards_state_dicts[tp_rank][
"transformer." + state_name
] = state.detach().clone()
del state_dict[state_name] # delete key from state_dict
del state # delete tensor
# we save state_dict
for tp_rank, (save_path, shard_state_dict) in enumerate(
zip(save_paths, shards_state_dicts)
):
save_path.parent.mkdir(parents=True, exist_ok=True)
if save_path.exists():
print(f"Skipping {save_path} as it already exists")
else:
torch.save(shard_state_dict, save_path)
return save_paths
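# Illustrative sketch (not part of the original commit): the splitting rule used
# above, on a toy tensor. Column-parallel weights (query_key_value, dense_h_to_4h,
# embeddings) are split along dim 0, row-parallel weights (self_attention.dense,
# dense_4h_to_h) along dim 1, and row-parallel biases are kept on rank 0 only with
# zeros on the other ranks.
def _tp_split_example(tp_world_size: int = 2):
    weight = torch.arange(4 * 6, dtype=torch.float32).view(4, 6)
    column_shards = torch.split(weight, weight.shape[0] // tp_world_size, dim=0)
    row_shards = torch.split(weight, weight.shape[1] // tp_world_size, dim=1)
    assert column_shards[0].shape == (2, 6)
    assert row_shards[0].shape == (4, 3)
    return column_shards, row_shards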
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--hub-path", required=True, type=str)
parser.add_argument("--save-path", required=True, type=str)
parser.add_argument("--world-size", required=True, type=int)
args = parser.parse_args()
prepare_weights(Path(args.hub_path), Path(args.save_path), args.world_size)
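# --- bloom_inference.server ---
# asyncio gRPC front-end exposing the TextGeneration service over unix sockets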
import asyncio
from grpc import aio
from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import Optional, List
from bloom_inference.cache import Cache
from bloom_inference.model import BLOOM, Batch, BLOOMSharded
from bloom_inference.pb import generate_pb2_grpc, generate_pb2
class TextGeneration(generate_pb2_grpc.TextGenerationServicer):
def __init__(self, model: BLOOM, cache: Cache, server_urls: List[str]):
self.cache = cache
self.model = model
self.server_urls = server_urls
async def ServiceDiscovery(self, request, context):
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
async def ClearCache(self, request, context):
self.cache.clear()
return generate_pb2.Empty()
async def Generate(self, request, context):
batch = Batch.from_batch_pb(request, self.model.tokenizer, self.model.device)
finished_generations, cache_entry = self.model.generate_token(batch)
self.cache.set(cache_entry)
return generate_pb2.Response(
finished=[
finished_generation.to_pb()
for finished_generation in finished_generations
],
cache_entry=cache_entry.to_pb() if cache_entry else None,
)
async def GenerateWithCache(self, request, context):
batch = Batch.from_batch_cached_pb(request, self.cache)
finished_generations, cache_entry = self.model.generate_token(batch)
self.cache.set(cache_entry)
return generate_pb2.Response(
finished=[
finished_generation.to_pb()
for finished_generation in finished_generations
],
cache_entry=cache_entry.to_pb() if cache_entry else None,
)
def serve(model_name, sharded, shard_directory):
async def serve_inner(
model_name: str,
sharded: bool = False,
shard_directory: Optional[Path] = None,
):
unix_socket_template = "unix:///tmp/bloom-inference-{}"
if sharded:
if shard_directory is None:
raise ValueError("shard_directory must be set when sharded is True")
model = BLOOMSharded(model_name, shard_directory)
server_urls = [
unix_socket_template.format(rank) for rank in range(model.world_size)
]
local_url = unix_socket_template.format(model.rank)
else:
model = BLOOM(model_name)
local_url = unix_socket_template.format(0)
server_urls = [local_url]
server = aio.server()
generate_pb2_grpc.add_TextGenerationServicer_to_server(
TextGeneration(model, Cache(), server_urls), server
)
SERVICE_NAMES = (
generate_pb2.DESCRIPTOR.services_by_name["TextGeneration"].full_name,
reflection.SERVICE_NAME,
)
reflection.enable_server_reflection(SERVICE_NAMES, server)
server.add_insecure_port(local_url)
await server.start()
print("Server started at {}".format(local_url))
await server.wait_for_termination()
asyncio.run(serve_inner(model_name, sharded, shard_directory))
if __name__ == "__main__":
serve("bigscience/bloom-560m", True, Path("/tmp/models"))
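# Note: in the sharded case every tensor-parallel rank runs its own gRPC server on
# its own unix socket; ServiceDiscovery returns the full list of sockets so that a
# client can reach every rank.


# --- bloom_inference.shard_model ---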
from pathlib import Path
import torch
from torch import nn
from transformers import AutoModelForCausalLM
def match_suffix(text, suffix):
return text[-len(suffix) :] == suffix
def shard_model(model_name: str, path: Path, tp_world_size: int, dtype: torch.dtype):
"""BLOOM specific sharding mechanism"""
save_paths = [
path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
for tp_rank in range(tp_world_size)
]
if all(save_path.exists() for save_path in save_paths):
print("Loading already cached values")
return save_paths
model: nn.Module = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, local_files_only=True
)
shards_state_dicts = [{} for _ in range(tp_world_size)]
state_dict = model.state_dict()
keys = list(state_dict.keys())
for state_name in keys:
print(state_name)
state = state_dict[state_name]
if any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.query_key_value.weight",
"self_attention.query_key_value.bias",
"mlp.dense_h_to_4h.weight",
"mlp.dense_h_to_4h.bias",
"transformer.word_embeddings.weight",
"lm_head.weight",
]
):
output_size = state.shape[0]
assert output_size % tp_world_size == 0
block_size = output_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=0)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[0] == block_size
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.weight",
"mlp.dense_4h_to_h.weight",
"lm_head.weight",
]
):
input_size = state.shape[1]
assert input_size % tp_world_size == 0
block_size = input_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=1)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[1] == block_size
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.bias",
"mlp.dense_4h_to_h.bias",
]
):
shards_state_dicts[0][state_name] = state.detach().clone()
for tp_rank in range(1, tp_world_size):
shards_state_dicts[tp_rank][state_name] = torch.zeros_like(state)
else:
# We duplicate parameters across tp ranks
for tp_rank in range(tp_world_size):
shards_state_dicts[tp_rank][state_name] = state.detach().clone()
del state_dict[state_name] # delete key from state_dict
del state # delete tensor
# we save state_dict
for tp_rank, (save_path, shard_state_dict) in enumerate(
zip(save_paths, shards_state_dicts)
):
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(shard_state_dict, save_path)
return save_paths
if __name__ == "__main__":
model_name = "bigscience/bloom"
save_path = Path("/data/shards")
tp_world_size = 8
dtype = torch.bfloat16
shard_model(model_name, save_path, tp_world_size=tp_world_size, dtype=dtype)