Unverified Commit e74bd41e authored by OlivierDehaene, committed by GitHub

feat(server): add paged attention to flash models (#516)

Closes #478
parent 70f485bf
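
Paged attention replaces the per-request contiguous KV cache with a pool of fixed-size blocks, so every flash model now reports its layer count, number of key/value heads, and head size to the shared FlashCausalLM base class, which uses them to size the block pool. A minimal sketch of that sizing arithmetic, assuming a hypothetical BLOCK_SIZE constant and allocate_kv_cache helper (illustrative only, not this commit's code):

import torch

BLOCK_SIZE = 16  # tokens per cache block; an assumed constant

def allocate_kv_cache(num_layers, num_kv_heads, head_size, num_blocks,
                      dtype=torch.float16, device="cpu"):
    # One (key, value) pair of block pools per transformer layer.
    # Each block stores BLOCK_SIZE tokens for every key/value head, so the
    # total cache size scales with num_layers * num_kv_heads * head_size.
    return [
        (
            torch.empty(num_blocks, num_kv_heads, head_size, BLOCK_SIZE,
                        dtype=dtype, device=device),
            torch.empty(num_blocks, num_kv_heads, head_size, BLOCK_SIZE,
                        dtype=dtype, device=device),
        )
        for _ in range(num_layers)
    ]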
@@ -64,10 +64,12 @@ class FlashLlama(FlashCausalLM):
         model = FlashLlamaForCausalLM(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(FlashCausalLM, self).__init__(
+        super(FlashLlama, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            requires_padding=False,
+            num_layers=len(model.model.layers),
+            num_kv_heads=model.model.num_heads,
+            head_size=model.model.head_size,
             dtype=dtype,
             device=device,
             rank=rank,
......
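
The one-line super() change repeated across these files is the heart of the refactor: super(FlashCausalLM, self) starts method lookup after FlashCausalLM in the MRO, so the old code deliberately skipped FlashCausalLM.__init__. Now that the base class owns the paged-cache setup, subclasses must run it, hence super(FlashLlama, self). A self-contained illustration of the difference:

class Model:
    def __init__(self):
        print("Model.__init__")

class FlashCausalLM(Model):
    def __init__(self):
        print("FlashCausalLM.__init__")  # in the real code: paged-cache setup
        super().__init__()

class FlashLlama(FlashCausalLM):
    def __init__(self, skip_parent):
        if skip_parent:
            super(FlashCausalLM, self).__init__()  # old: bypasses FlashCausalLM
        else:
            super(FlashLlama, self).__init__()     # new: runs FlashCausalLM too

FlashLlama(skip_parent=True)   # prints only Model.__init__
FlashLlama(skip_parent=False)  # prints FlashCausalLM.__init__, then Model.__init__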
@@ -55,10 +55,12 @@ class FlashNeoXSharded(FlashCausalLM):
         model = FlashGPTNeoXForCausalLM(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(FlashCausalLM, self).__init__(
+        super(FlashNeoXSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            requires_padding=False,
+            num_layers=len(model.gpt_neox.layers),
+            num_kv_heads=model.gpt_neox.num_heads,
+            head_size=model.gpt_neox.head_size,
             dtype=dtype,
             device=device,
             rank=rank,
......
@@ -55,10 +55,12 @@ class FlashRWSharded(FlashCausalLM):
         model = FlashRWForCausalLM(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(FlashCausalLM, self).__init__(
+        super(FlashRWSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            requires_padding=False,
+            num_layers=len(model.transformer.h),
+            num_kv_heads=model.transformer.cache_size,
+            head_size=model.transformer.head_size,
             dtype=dtype,
             device=device,
             rank=rank,
......
@@ -52,17 +52,22 @@ class FlashSantacoderSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group,
-            aliases = {"transformer.wte.weight": ["lm_head.weight"]}
+            filenames,
+            device=device,
+            dtype=dtype,
+            process_group=self.process_group,
+            aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
 
         model = FlashSantacoderForCausalLM(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(FlashCausalLM, self).__init__(
+        super(FlashSantacoderSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            requires_padding=False,
+            num_layers=len(model.transformer.h),
+            num_kv_heads=1,
+            head_size=model.transformer.head_size,
             dtype=dtype,
             device=device,
             rank=rank,
......
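
num_kv_heads=1 is not a typo: Santacoder uses multi-query attention, where all query heads share a single key/value head, so its paged cache is a fraction of the multi-head equivalent. A toy comparison of the cache shapes (illustrative values):

import torch

batch, seq, num_heads, head_size = 2, 128, 16, 64

# Multi-head attention: one K (and V) slot per query head.
k_mha = torch.empty(batch, seq, num_heads, head_size)

# Multi-query attention (Santacoder): one shared K/V head,
# broadcast across all query heads at attention time.
k_mqa = torch.empty(batch, seq, 1, head_size)

print(k_mha.numel() // k_mqa.numel())  # 16: the KV cache shrinks by num_heads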
@@ -22,6 +22,9 @@ class Model(ABC):
         rank: int = 0,
         world_size: int = 1,
     ):
+        if torch.cuda.is_available():
+            torch.cuda.set_per_process_memory_fraction(1.0)
+
         self.model = model.eval()
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
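
set_per_process_memory_fraction(1.0) makes the allocator's cap explicit: the process may use the whole GPU, and the warmup pass below can then probe real headroom. For illustration (not from this commit), the same call with a smaller fraction shows how the cap behaves:

import torch

if torch.cuda.is_available():
    total = torch.cuda.get_device_properties(0).total_memory
    # Cap this process at half the card; allocations past the cap raise
    # torch.cuda.OutOfMemoryError even if the GPU itself still has room.
    torch.cuda.set_per_process_memory_fraction(0.5)
    try:
        torch.empty(int(total * 0.75), dtype=torch.uint8, device="cuda")
    except torch.cuda.OutOfMemoryError:
        print("capped by set_per_process_memory_fraction")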
@@ -55,6 +58,9 @@ class Model(ABC):
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError
 
+    def warmup(self, batch: B, max_total_tokens: int):
+        self.generate_token(batch)
+
     def decode_token(
         self,
         all_input_ids: List[int],
......
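
The base warmup just runs one worst-case generate_token pass; the flash models are expected to override it and turn leftover GPU memory into KV-cache blocks. A hedged sketch of that accounting, assuming a hypothetical estimate_free_blocks helper:

import torch

def estimate_free_blocks(block_bytes: int, safety_margin: float = 0.9) -> int:
    # After the worst-case warmup forward pass, whatever memory is still
    # free (times a safety margin) can be handed to the paged KV cache.
    torch.cuda.synchronize()
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    return int(free_bytes * safety_margin) // block_bytes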
@@ -127,7 +127,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
 
-        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
         return cls(
             batch_id=pb.id,
......
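
The added parentheses fix real under-accounting: max_tokens is the batch's worst-case token budget, and every request can grow to max_input_length + max_decode_tokens tokens. With 4 requests, max_input_length = 100 and max_decode_tokens = 50, the old expression reserved 4 * 100 + 50 = 450 tokens, while the true worst case is 4 * (100 + 50) = 600, so the scheduler could overcommit memory.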
@@ -53,6 +53,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
+    async def Warmup(self, request, context):
+        batch = self.model.batch_type.from_pb(
+            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+        )
+        self.model.warmup(batch, request.max_total_tokens)
+        return generate_pb2.WarmupResponse()
+
     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
......
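
The new Warmup RPC lets the router drive the warmup pass before accepting traffic. A hypothetical client call; the WarmupRequest message name, any fields beyond the two read above, and the import path are assumptions:

import grpc

from text_generation_server.pb import generate_pb2, generate_pb2_grpc  # assumed path

async def warmup(channel: grpc.aio.Channel, batch_pb, max_total_tokens: int):
    stub = generate_pb2_grpc.TextGenerationServiceStub(channel)
    # Mirrors the fields the servicer reads: request.batch and
    # request.max_total_tokens.
    await stub.Warmup(
        generate_pb2.WarmupRequest(batch=batch_pb, max_total_tokens=max_total_tokens)
    )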
@@ -216,6 +216,8 @@ class HeterogeneousNextTokenChooser:
         self.seeds = seeds
         self.do_sample = do_sample
+        self.dtype = dtype
+        self.device = device
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
         if self.watermark_processor is not None:
......
@@ -5,7 +5,14 @@ import torch
 
 
 class Weights:
-    def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
+    def __init__(
+        self,
+        filenames: List[Path],
+        device,
+        dtype,
+        process_group,
+        aliases: Optional[Dict[str, List[str]]] = None,
+    ):
         routing = {}
         for filename in filenames:
             with safe_open(filename, framework="pytorch") as f:
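
The aliases argument (only reformatted here, used by FlashSantacoderSharded above) handles tied weights: lm_head.weight may be absent from the checkpoint because it is the same tensor as transformer.wte.weight, so lookups for the alias must route to the canonical name. A minimal sketch of that resolution, assuming a hypothetical build_routing helper rather than the file's exact code:

def build_routing(tensor_names_by_file, aliases=None):
    # Map every tensor name, plus any aliases of it, to the file holding it.
    aliases = aliases or {}
    routing = {}
    for filename, names in tensor_names_by_file.items():
        for name in names:
            routing[name] = filename
            for alias in aliases.get(name, []):
                routing[alias] = filename
    return routing

routing = build_routing(
    {"model.safetensors": ["transformer.wte.weight"]},
    aliases={"transformer.wte.weight": ["lm_head.weight"]},
)
assert routing["lm_head.weight"] == "model.safetensors"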
@@ -43,7 +50,7 @@ class Weights:
         return str(filename), tensor_name
 
     def _get_slice(self, tensor_name: str):
-        filename, tensor_name= self.get_filename(tensor_name)
+        filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
         return slice_
@@ -94,12 +101,20 @@ class Weights:
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
-                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
+                qweight = torch.cat(
+                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+                )
             except RuntimeError:
-                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
 
-            qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
-            scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
+            qzeros = torch.cat(
+                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+            )
+            scales = torch.cat(
+                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+            )
             w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
             for w2 in w[1:]:
                 torch.testing.assert_close(w2, w[0])
@@ -118,7 +133,9 @@ class Weights:
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
             except RuntimeError:
-                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
             qzeros = self.get_tensor(f"{prefix}.qzeros")
             scales = self.get_tensor(f"{prefix}.scales")
             g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
......
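
Both GPTQ hunks keep the same tensor-parallel layout and only re-wrap long lines: column-parallel layers split qweight, qzeros, and scales on their output dimension (dim=1) and concatenate the shards per prefix, while row-parallel layers split qweight on the input dimension (dim=0) and load the per-group qzeros/scales whole. A toy illustration of the two splits, ignoring GPTQ's bit-packing:

import torch

world_size, rank = 2, 0
qweight = torch.arange(8 * 6).reshape(8, 6)  # toy (in_features, out_features)

# Column parallel: each rank owns a slice of the output features.
col_shard = qweight.chunk(world_size, dim=1)[rank]  # shape (8, 3)

# Row parallel: each rank owns a slice of the input features.
row_shard = qweight.chunk(world_size, dim=0)[rank]  # shape (4, 6)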