"vscode:/vscode.git/clone" did not exist on "4f2ee48ed1c66ee0e189daa4120581de324ee814"
Commit 295831a4 authored by Olivier Dehaene

Init

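# --- bloom_inference/cache.py (path inferred from the `bloom_inference.cache` import further down) ---
# Holds the generation state of unfinished batches between gRPC calls.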
import torch
from dataclasses import dataclass
from typing import Dict, Optional, List
from bloom_inference.pb import generate_pb2
from bloom_inference.utils import NextTokenChooser, StoppingCriteria
@dataclass
class CacheEntry:
batch_id: int
request_ids: List[int]
input_ids: Dict[str, torch.Tensor]
all_input_ids: List[torch.Tensor]
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
def __len__(self):
return len(self.request_ids)
def to_pb(self):
return generate_pb2.CacheEntry(
id=self.batch_id,
request_ids=self.request_ids,
sequence_length=max(len(entry) for entry in self.all_input_ids),
)
class Cache:
def __init__(self):
self.cache: Dict[int, CacheEntry] = {}
def pop(self, batch_id: int) -> Optional[CacheEntry]:
return self.cache.pop(batch_id, None)
def set(self, entry: Optional[CacheEntry]):
if entry is not None:
self.cache[entry.batch_id] = entry
def delete(self, batch_id: int):
del self.cache[batch_id]
def clear(self):
self.cache.clear()
def __len__(self):
return len(self.cache)
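# A minimal, illustrative sketch of how the cache above is used (names taken from this file):
#
#     cache = Cache()
#     cache.set(entry)             # stored under entry.batch_id
#     entry = cache.pop(batch_id)  # returns None when the batch is unknown
#     cache.clear()
#
# --- CLI entrypoint (file name not shown in this view; likely bloom_inference/main.py or cli.py) ---
# Starts a single-process server, or one process per GPU through torchelastic when num_gpus > 1.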
import typer
from pathlib import Path
from torch.distributed.launcher import launch_agent, LaunchConfig
from typing import Optional
from bloom_inference.server import serve
def main(
model_name: str,
num_gpus: int = 1,
shard_directory: Optional[Path] = None,
):
if num_gpus == 1:
serve(model_name, False, shard_directory)
else:
config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=num_gpus,
rdzv_backend="c10d",
max_restarts=0,
)
launch_agent(config, serve, [model_name, True, shard_directory])
if __name__ == "__main__":
typer.run(main)
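# Hypothetical invocation of the CLI above (typer derives the options from the signature;
# the module path is an assumption):
#
#     python -m bloom_inference.main bigscience/bloom-560m --num-gpus 2 --shard-directory /tmp/models
#
# --- bloom_inference/model.py (path inferred from the `bloom_inference.model` import in server.py below) ---
# Batch construction from protobuf messages, plus the BLOOM model and its tensor-parallel BLOOMSharded variant.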
import torch
import torch.distributed
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Optional, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.modeling_utils import no_init_weights
from bloom_inference.cache import CacheEntry
from bloom_inference.pb import generate_pb2
from bloom_inference.shard_model import shard_model, match_suffix
from bloom_inference.utils import (
StoppingCriteria,
NextTokenChooser,
initialize_torch_distributed,
set_default_dtype,
)
torch.manual_seed(0)
@dataclass
class Batch:
batch_id: int
request_ids: List[int]
input_ids: Dict[str, torch.Tensor]
all_input_ids: List[torch.Tensor]
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
@classmethod
def from_batch_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "Batch":
request_ids = []
inputs = []
next_token_choosers = []
stopping_criterias = []
# Parse batch
for r in pb.requests:
request_ids.append(r.id)
inputs.append(r.inputs)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
top_k=r.parameters.top_k,
top_p=r.parameters.top_p,
do_sample=r.parameters.do_sample,
)
)
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device)
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
return cls(
pb.id,
request_ids,
input_ids,
all_input_ids,
next_token_choosers,
stopping_criterias,
)
@classmethod
def from_cache_entry(cls, cache_entry: CacheEntry) -> "Batch":
return cls(
cache_entry.batch_id,
cache_entry.request_ids,
cache_entry.input_ids,
cache_entry.all_input_ids,
cache_entry.next_token_choosers,
cache_entry.stopping_criterias,
)
@classmethod
def from_batch_cached_pb(cls, pb: generate_pb2.BatchCached, cache) -> "Batch":
if len(pb.batch_cached_ids) == 1:
cache_entry = cache.pop(pb.batch_cached_ids[0])
if cache_entry is None:
raise ValueError(f"Batch ID {pb.batch_id} not found in cache")
return cls.from_cache_entry(cache_entry)
total_batch_size = pb.total_batch_size
max_sequence_length = pb.max_sequence_length
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
request_ids = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
start_index = 0
for i, batch_id in enumerate(pb.batch_cached_ids):
cache_entry = cache.pop(batch_id)
if cache_entry is None:
raise ValueError(f"Batch ID {batch_id} not found in cache")
request_ids.extend(cache_entry.request_ids)
all_input_ids.extend(cache_entry.all_input_ids)
next_token_choosers.extend(cache_entry.next_token_choosers)
stopping_criterias.extend(cache_entry.stopping_criterias)
batch_size = len(cache_entry.request_ids)
end_index = start_index + batch_size
sequence_length = max(len(entry) for entry in cache_entry.all_input_ids)
if input_ids["input_ids"] is None:
input_ids["input_ids"] = torch.empty(
(total_batch_size, 1),
dtype=cache_entry.input_ids["input_ids"].dtype,
device=cache_entry.input_ids["input_ids"].device,
)
input_ids["input_ids"][start_index:end_index] = cache_entry.input_ids[
"input_ids"
]
if input_ids["attention_mask"] is None:
input_ids["attention_mask"] = torch.zeros(
(total_batch_size, max_sequence_length),
dtype=cache_entry.input_ids["attention_mask"].dtype,
device=cache_entry.input_ids["attention_mask"].device,
)
input_ids["attention_mask"][
start_index:end_index, -sequence_length:
] = cache_entry.input_ids["attention_mask"][:, -sequence_length:]
for j, past in enumerate(cache_entry.input_ids["past_key_values"]):
# TODO: this could be done without the views by using indices
past_keys = past[0]
past_values = past[1]
_, head_dim, padded_sequence_length = past_keys.shape
past_keys = past_keys.view(
batch_size, -1, head_dim, padded_sequence_length
)
past_values = past_values.view(
batch_size, -1, padded_sequence_length, head_dim
)
num_heads = past_keys.shape[1]
if j == len(input_ids["past_key_values"]):
padded_past_keys = torch.zeros(
(
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
),
dtype=past_keys.dtype,
device=past_keys.device,
)
padded_past_values = torch.zeros(
(
total_batch_size,
num_heads,
max_sequence_length - 1,
head_dim,
),
dtype=past_values.dtype,
device=past_values.device,
)
input_ids["past_key_values"].append(
[padded_past_keys, padded_past_values]
)
input_ids["past_key_values"][j][0][
start_index:end_index, :, :, -(sequence_length - 1):
] = past_keys[:, :, :, -(sequence_length - 1):]
input_ids["past_key_values"][j][1][
start_index:end_index, :, -(sequence_length - 1):, :
] = past_values[:, :, -(sequence_length - 1):, :]
if (i + 1) == len(pb.batch_cached_ids):
input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
j
][0].view(total_batch_size * num_heads, head_dim, -1)
input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
j
][1].view(total_batch_size * num_heads, -1, head_dim)
start_index += batch_size
assert pb.request_ids == request_ids
return cls(
pb.id,
request_ids,
input_ids,
all_input_ids,
next_token_choosers,
stopping_criterias,
)
@dataclass
class FinishedGeneration:
request_id: int
output: str
def to_pb(self) -> generate_pb2.FinishedGeneration:
return generate_pb2.FinishedGeneration(id=self.request_id, output=self.output)
class BLOOM:
def __init__(self, model_name: str):
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.model = (
AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device)
)
self.num_heads = self.model.base_model.num_heads
def forward(self, input_ids, attention_mask, past_key_values: Optional[List] = None):
# Model Forward
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
def generate_token(
self, batch: Batch
) -> Tuple[List[FinishedGeneration], Optional[CacheEntry]]:
with torch.no_grad():
outputs = self.forward(**batch.input_ids)
# List of indices to cache
cache_indices = []
cache_past_indices = []
# New input_ids for next forward; keep in cache
cache_next_input_ids = []
cache_all_input_ids = []
# Finished requests
finished_generations: List[FinishedGeneration] = []
# Zipped iterator
iterator = zip(
batch.request_ids,
outputs.logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request_id,
logits,
next_token_chooser,
stopping_criteria,
all_tokens,
) in enumerate(iterator):
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
# Append next token to all tokens
all_tokens = torch.cat([all_tokens, next_token])
# Evaluate stopping criteria
if stopping_criteria(all_tokens):
# Decode all tokens
output = self.tokenizer.decode(
all_tokens.squeeze(-1), skip_special_tokens=True
)
# Add to the list of finished generations with the original request id
finished_generations.append(FinishedGeneration(request_id, output))
# must be added to the cache
else:
cache_indices.append(i)
cache_past_indices.extend(range(i * self.num_heads, (i + 1) * self.num_heads))
cache_next_input_ids.append(next_token)
cache_all_input_ids.append(all_tokens)
# No cache is needed, we finished all generations in the batch
if not cache_indices:
return finished_generations, None
# If we finished at least one generation
cache_input_ids = {"input_ids": torch.cat(cache_next_input_ids, dim=0)}
if finished_generations:
# Apply indices to attention mask, past key values and other items that need to be cached
cache_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
cache_indices
]
cache_input_ids["past_key_values"] = [
(keys[cache_past_indices], values[cache_past_indices])
for keys, values in outputs["past_key_values"]
]
cache_request_ids = [batch.request_ids[i] for i in cache_indices]
cache_next_token_choosers = [
batch.next_token_choosers[i] for i in cache_indices
]
cache_stopping_criterias = [
batch.stopping_criterias[i] for i in cache_indices
]
else:
cache_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
cache_input_ids["past_key_values"] = outputs["past_key_values"]
cache_request_ids = batch.request_ids
cache_next_token_choosers = batch.next_token_choosers
cache_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
cache_input_ids["attention_mask"] = torch.cat(
[
cache_input_ids["attention_mask"],
torch.ones((cache_input_ids["attention_mask"].shape[0], 1)).to(
cache_input_ids["attention_mask"].device
),
],
dim=1,
)
cache_entry = CacheEntry(
batch.batch_id,
cache_request_ids,
cache_input_ids,
cache_all_input_ids,
cache_next_token_choosers,
cache_stopping_criterias,
)
return finished_generations, cache_entry
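# generate_token runs a single forward pass over the whole batch: finished requests are decoded and
# returned as FinishedGeneration, while unfinished ones are packed into a CacheEntry (next input ids,
# attention mask, past_key_values) so a later GenerateWithCache call can resume from it.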
class BLOOMSharded(BLOOM):
def __init__(self, model_name: str, shard_directory: Path):
super(BLOOM, self).__init__()
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():
self.device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16
else:
self.device = torch.device("cpu")
dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# shard state_dict
if self.master:
# TODO @thomasw21 do some caching
shard_state_dict_paths = shard_model(
model_name, shard_directory, tp_world_size=self.world_size, dtype=dtype
)
shard_state_dict_paths = [
str(path.absolute()) for path in shard_state_dict_paths
]
else:
shard_state_dict_paths = [None] * self.world_size
torch.distributed.broadcast_object_list(
shard_state_dict_paths, src=0, group=self.process_group
)
shard_state_dict_path = shard_state_dict_paths[self.rank]
config = AutoConfig.from_pretrained(
model_name, slow_but_exact=False, tp_parallel=True
)
config.pad_token_id = 3
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True
with set_default_dtype(dtype):
with no_init_weights():
# we can probably set the device to `meta` here?
model = AutoModelForCausalLM.from_config(config).to(dtype)
torch.distributed.barrier(group=self.process_group)
# print_rank_0(f"Initialized model")
state_dict = torch.load(shard_state_dict_path)
# TODO @thomasw21: HACK in order to transpose all weight prior
for key in state_dict.keys():
do_transpose = False
if not match_suffix(key, "weight"):
continue
for potential_suffix in [
"self_attention.query_key_value.weight",
"self_attention.dense.weight",
"dense_h_to_4h.weight",
"dense_4h_to_h.weight",
]:
if match_suffix(key, potential_suffix):
do_transpose = True
if do_transpose:
state_dict[key] = state_dict[key].transpose(1, 0).contiguous()
model.load_state_dict(state_dict)
self.model = model.to(self.device).eval()
self.num_heads = config.n_head // self.process_group.size()
torch.distributed.barrier(group=self.process_group)
def forward(self, input_ids, attention_mask, past_key_values: Optional[List] = None):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
logits_shard = outputs.logits[:, -1, :].contiguous()
batch_size, vocab_shard_size = logits_shard.shape
vocab_size = self.world_size * vocab_shard_size
logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
outputs.logits = logits
return outputs
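# Each tensor-parallel rank holds a vocabulary shard of the LM head, so forward() above all-gathers
# the last-position logits across ranks and concatenates them along the vocabulary dimension.
#
# --- bloom_inference/pb/generate_pb2.py (generated from generate.proto; package path inferred from
# the `bloom_inference.pb` imports) ---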
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: generate.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0egenerate.proto\x12\x0bgenerate.v1\"(\n\x18ServiceDiscoveryResponse\x12\x0c\n\x04urls\x18\x01 \x03(\t\"^\n\x16LogitsWarperParameters\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_k\x18\x02 \x01(\r\x12\r\n\x05top_p\x18\x03 \x01(\x02\x12\x11\n\tdo_sample\x18\x04 \x01(\x08\"v\n\x07Request\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0e\n\x06inputs\x18\x02 \x01(\t\x12\x37\n\nparameters\x18\x03 \x01(\x0b\x32#.generate.v1.LogitsWarperParameters\x12\x16\n\x0emax_new_tokens\x18\x04 \x01(\r\";\n\x05\x42\x61tch\x12\n\n\x02id\x18\x01 \x01(\x04\x12&\n\x08requests\x18\x02 \x03(\x0b\x32\x14.generate.v1.Request\"\x7f\n\x0b\x42\x61tchCached\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x13\n\x0brequest_ids\x18\x02 \x03(\x04\x12\x18\n\x10\x62\x61tch_cached_ids\x18\x03 \x03(\x04\x12\x18\n\x10total_batch_size\x18\x04 \x01(\r\x12\x1b\n\x13max_sequence_length\x18\x05 \x01(\r\"0\n\x12\x46inishedGeneration\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0e\n\x06output\x18\x02 \x01(\t\"F\n\nCacheEntry\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x13\n\x0brequest_ids\x18\x02 \x03(\x04\x12\x17\n\x0fsequence_length\x18\x03 \x01(\r\"\x80\x01\n\x08Response\x12\x31\n\x08\x66inished\x18\x01 \x03(\x0b\x32\x1f.generate.v1.FinishedGeneration\x12\x31\n\x0b\x63\x61\x63he_entry\x18\x02 \x01(\x0b\x32\x17.generate.v1.CacheEntryH\x00\x88\x01\x01\x42\x0e\n\x0c_cache_entry\"\x07\n\x05\x45mpty2\x94\x02\n\x0eTextGeneration\x12O\n\x10ServiceDiscovery\x12\x12.generate.v1.Empty\x1a%.generate.v1.ServiceDiscoveryResponse\"\x00\x12\x34\n\nClearCache\x12\x12.generate.v1.Empty\x1a\x12.generate.v1.Empty\x12\x35\n\x08Generate\x12\x12.generate.v1.Batch\x1a\x15.generate.v1.Response\x12\x44\n\x11GenerateWithCache\x12\x18.generate.v1.BatchCached\x1a\x15.generate.v1.Responseb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'generate_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_SERVICEDISCOVERYRESPONSE._serialized_start=31
_SERVICEDISCOVERYRESPONSE._serialized_end=71
_LOGITSWARPERPARAMETERS._serialized_start=73
_LOGITSWARPERPARAMETERS._serialized_end=167
_REQUEST._serialized_start=169
_REQUEST._serialized_end=287
_BATCH._serialized_start=289
_BATCH._serialized_end=348
_BATCHCACHED._serialized_start=350
_BATCHCACHED._serialized_end=477
_FINISHEDGENERATION._serialized_start=479
_FINISHEDGENERATION._serialized_end=527
_CACHEENTRY._serialized_start=529
_CACHEENTRY._serialized_end=599
_RESPONSE._serialized_start=602
_RESPONSE._serialized_end=730
_EMPTY._serialized_start=732
_EMPTY._serialized_end=739
_TEXTGENERATION._serialized_start=742
_TEXTGENERATION._serialized_end=1018
# @@protoc_insertion_point(module_scope)
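# --- bloom_inference/pb/generate_pb2_grpc.py (generated from generate.proto) ---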
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
from . import generate_pb2 as generate__pb2
class TextGenerationStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.ServiceDiscovery = channel.unary_unary(
'/generate.v1.TextGeneration/ServiceDiscovery',
request_serializer=generate__pb2.Empty.SerializeToString,
response_deserializer=generate__pb2.ServiceDiscoveryResponse.FromString,
)
self.ClearCache = channel.unary_unary(
'/generate.v1.TextGeneration/ClearCache',
request_serializer=generate__pb2.Empty.SerializeToString,
response_deserializer=generate__pb2.Empty.FromString,
)
self.Generate = channel.unary_unary(
'/generate.v1.TextGeneration/Generate',
request_serializer=generate__pb2.Batch.SerializeToString,
response_deserializer=generate__pb2.Response.FromString,
)
self.GenerateWithCache = channel.unary_unary(
'/generate.v1.TextGeneration/GenerateWithCache',
request_serializer=generate__pb2.BatchCached.SerializeToString,
response_deserializer=generate__pb2.Response.FromString,
)
class TextGenerationServicer(object):
"""Missing associated documentation comment in .proto file."""
def ServiceDiscovery(self, request, context):
"""/ Service discovery
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ClearCache(self, request, context):
"""/ Empties batch cache
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Generate(self, request, context):
"""/ Generate tokens for a batch without cache
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GenerateWithCache(self, request, context):
"""/ Generate tokens for a batch with cache
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_TextGenerationServicer_to_server(servicer, server):
rpc_method_handlers = {
'ServiceDiscovery': grpc.unary_unary_rpc_method_handler(
servicer.ServiceDiscovery,
request_deserializer=generate__pb2.Empty.FromString,
response_serializer=generate__pb2.ServiceDiscoveryResponse.SerializeToString,
),
'ClearCache': grpc.unary_unary_rpc_method_handler(
servicer.ClearCache,
request_deserializer=generate__pb2.Empty.FromString,
response_serializer=generate__pb2.Empty.SerializeToString,
),
'Generate': grpc.unary_unary_rpc_method_handler(
servicer.Generate,
request_deserializer=generate__pb2.Batch.FromString,
response_serializer=generate__pb2.Response.SerializeToString,
),
'GenerateWithCache': grpc.unary_unary_rpc_method_handler(
servicer.GenerateWithCache,
request_deserializer=generate__pb2.BatchCached.FromString,
response_serializer=generate__pb2.Response.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'generate.v1.TextGeneration', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class TextGeneration(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def ServiceDiscovery(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/ServiceDiscovery',
generate__pb2.Empty.SerializeToString,
generate__pb2.ServiceDiscoveryResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def ClearCache(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/ClearCache',
generate__pb2.Empty.SerializeToString,
generate__pb2.Empty.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Generate(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/Generate',
generate__pb2.Batch.SerializeToString,
generate__pb2.Response.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def GenerateWithCache(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/generate.v1.TextGeneration/GenerateWithCache',
generate__pb2.BatchCached.SerializeToString,
generate__pb2.Response.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
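# A minimal client sketch against the generated stub above; the unix socket target matches the
# template used in server.py below, and the field names come from generate.proto:
#
#     import grpc
#     from bloom_inference.pb import generate_pb2, generate_pb2_grpc
#
#     channel = grpc.insecure_channel("unix:///tmp/bloom-inference-0")
#     stub = generate_pb2_grpc.TextGenerationStub(channel)
#     request = generate_pb2.Batch(
#         id=0,
#         requests=[
#             generate_pb2.Request(
#                 id=0,
#                 inputs="Hello",
#                 parameters=generate_pb2.LogitsWarperParameters(
#                     temperature=1.0, top_k=0, top_p=1.0, do_sample=False
#                 ),
#                 max_new_tokens=20,
#             )
#         ],
#     )
#     response = stub.Generate(request)
#
# --- weight preparation script (file name not shown in this view; likely prepare_weights.py) ---
# Splits already-downloaded hub checkpoint files into one state dict per tensor-parallel rank.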
import torch
from pathlib import Path
from tqdm import tqdm
MODEL_NAME = "bigscience/bloom"
def match_suffix(text, suffix):
return text[-len(suffix) :] == suffix
def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
save_paths = [
save_path / f"{MODEL_NAME}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
for tp_rank in range(tp_world_size)
]
if all(save_path.exists() for save_path in save_paths):
print("Weights are already prepared")
return save_paths
shards_state_dicts = [{} for _ in range(tp_world_size)]
for weight_path in tqdm(hub_path.glob("*.bin")):
state_dict = torch.load(weight_path, map_location="cpu")
keys = list(state_dict.keys())
for state_name in keys:
state = state_dict[state_name]
if any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.query_key_value.weight",
"self_attention.query_key_value.bias",
"mlp.dense_h_to_4h.weight",
"mlp.dense_h_to_4h.bias",
"word_embeddings.weight",
"lm_head.weight",
]
):
output_size = state.shape[0]
assert output_size % tp_world_size == 0
block_size = output_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=0)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[0] == block_size
if match_suffix(state_name, "lm_head.weight"):
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
else:
shards_state_dicts[tp_rank][
"transformer." + state_name
] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.weight",
"mlp.dense_4h_to_h.weight",
"lm_head.weight",
]
):
input_size = state.shape[1]
assert input_size % tp_world_size == 0
block_size = input_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=1)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[1] == block_size
if match_suffix(state_name, "lm_head.weight"):
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
else:
shards_state_dicts[tp_rank][
"transformer." + state_name
] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.bias",
"mlp.dense_4h_to_h.bias",
]
):
shards_state_dicts[0][
"transformer." + state_name
] = state.detach().clone()
for tp_rank in range(1, tp_world_size):
shards_state_dicts[tp_rank][
"transformer." + state_name
] = torch.zeros_like(state)
else:
# We duplicate parameters across tp ranks
for tp_rank in range(tp_world_size):
shards_state_dicts[tp_rank][
"transformer." + state_name
] = state.detach().clone()
del state_dict[state_name] # delete key from state_dict
del state # delete tensor
# we save state_dict
for tp_rank, (save_path, shard_state_dict) in enumerate(
zip(save_paths, shards_state_dicts)
):
save_path.parent.mkdir(parents=True, exist_ok=True)
if save_path.exists():
print(f"Skipping {save_path} as it already exists")
else:
torch.save(shard_state_dict, save_path)
return save_paths
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--hub-path", required=True, type=str)
parser.add_argument("--save-path", required=True, type=str)
parser.add_argument("--world-size", required=True, type=int)
args = parser.parse_args()
prepare_weights(Path(args.hub_path), Path(args.save_path), args.world_size)
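# Hypothetical invocation of the script above (argument names from the ArgumentParser; the file name
# is an assumption):
#
#     python prepare_weights.py --hub-path /path/to/hub/snapshot --save-path /data/shards --world-size 8
#
# --- bloom_inference/server.py (path inferred from the `bloom_inference.server` import in the CLI) ---
# gRPC servicer exposing ServiceDiscovery, ClearCache, Generate and GenerateWithCache over unix sockets.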
import asyncio
from grpc import aio
from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import Optional, List
from bloom_inference.cache import Cache
from bloom_inference.model import BLOOM, Batch, BLOOMSharded
from bloom_inference.pb import generate_pb2_grpc, generate_pb2
class TextGeneration(generate_pb2_grpc.TextGenerationServicer):
def __init__(self, model: BLOOM, cache: Cache, server_urls: List[str]):
self.cache = cache
self.model = model
self.server_urls = server_urls
async def ServiceDiscovery(self, request, context):
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
async def ClearCache(self, request, context):
self.cache.clear()
return generate_pb2.Empty()
async def Generate(self, request, context):
batch = Batch.from_batch_pb(request, self.model.tokenizer, self.model.device)
finished_generations, cache_entry = self.model.generate_token(batch)
self.cache.set(cache_entry)
return generate_pb2.Response(
finished=[
finished_generation.to_pb()
for finished_generation in finished_generations
],
cache_entry=cache_entry.to_pb() if cache_entry else None,
)
async def GenerateWithCache(self, request, context):
batch = Batch.from_batch_cached_pb(request, self.cache)
finished_generations, cache_entry = self.model.generate_token(batch)
self.cache.set(cache_entry)
return generate_pb2.Response(
finished=[
finished_generation.to_pb()
for finished_generation in finished_generations
],
cache_entry=cache_entry.to_pb() if cache_entry else None,
)
def serve(model_name, sharded, shard_directory):
async def serve_inner(
model_name: str,
sharded: bool = False,
shard_directory: Optional[Path] = None,
):
unix_socket_template = "unix:///tmp/bloom-inference-{}"
if sharded:
if shard_directory is None:
raise ValueError("shard_directory must be set when sharded is True")
model = BLOOMSharded(model_name, shard_directory)
server_urls = [
unix_socket_template.format(rank) for rank in range(model.world_size)
]
local_url = unix_socket_template.format(model.rank)
else:
model = BLOOM(model_name)
local_url = unix_socket_template.format(0)
server_urls = [local_url]
server = aio.server()
generate_pb2_grpc.add_TextGenerationServicer_to_server(
TextGeneration(model, Cache(), server_urls), server
)
SERVICE_NAMES = (
generate_pb2.DESCRIPTOR.services_by_name["TextGeneration"].full_name,
reflection.SERVICE_NAME,
)
reflection.enable_server_reflection(SERVICE_NAMES, server)
server.add_insecure_port(local_url)
await server.start()
print("Server started at {}".format(local_url))
await server.wait_for_termination()
asyncio.run(serve_inner(model_name, sharded, shard_directory))
if __name__ == "__main__":
serve("bigscience/bloom-560m", True, Path("/tmp/models"))
from pathlib import Path
import torch
from torch import nn
from transformers import AutoModelForCausalLM
def match_suffix(text, suffix):
return text[-len(suffix) :] == suffix
def shard_model(model_name: str, path: Path, tp_world_size: int, dtype: torch.dtype):
"""BLOOM specific sharding mechanism"""
save_paths = [
path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
for tp_rank in range(tp_world_size)
]
if all(save_path.exists() for save_path in save_paths):
print("Loading already cached values")
return save_paths
model: nn.Module = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, local_files_only=True
)
shards_state_dicts = [{} for _ in range(tp_world_size)]
state_dict = model.state_dict()
keys = list(state_dict.keys())
for state_name in keys:
print(state_name)
state = state_dict[state_name]
if any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.query_key_value.weight",
"self_attention.query_key_value.bias",
"mlp.dense_h_to_4h.weight",
"mlp.dense_h_to_4h.bias",
"transformer.word_embeddings.weight",
"lm_head.weight",
]
):
output_size = state.shape[0]
assert output_size % tp_world_size == 0
block_size = output_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=0)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[0] == block_size
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.weight",
"mlp.dense_4h_to_h.weight",
"lm_head.weight",
]
):
input_size = state.shape[1]
assert input_size % tp_world_size == 0
block_size = input_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=1)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[1] == block_size
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.bias",
"mlp.dense_4h_to_h.bias",
]
):
shards_state_dicts[0][state_name] = state.detach().clone()
for tp_rank in range(1, tp_world_size):
shards_state_dicts[tp_rank][state_name] = torch.zeros_like(state)
else:
# We duplicate parameters across tp ranks
for tp_rank in range(tp_world_size):
shards_state_dicts[tp_rank][state_name] = state.detach().clone()
del state_dict[state_name] # delete key from state_dict
del state # delete tensor
# we save state_dict
for tp_rank, (save_path, shard_state_dict) in enumerate(
zip(save_paths, shards_state_dicts)
):
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(shard_state_dict, save_path)
return save_paths
if __name__ == "__main__":
model_name = "bigscience/bloom"
save_path = Path("/data/shards")
tp_world_size = 8
dtype = torch.bfloat16
shard_model(model_name, save_path, tp_world_size=tp_world_size, dtype=dtype)