Unverified commit 47954b81, authored by OlivierDehaene and committed by GitHub

feat: format code (#1070)

parent b32e9ce9
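The hunks below are consistent with an automated formatter pass (long calls rewrapped with trailing commas, slice spacing normalized, blank lines adjusted). The commit message does not name the tool, but the style matches black's defaults; here is a minimal sketch, under that assumption, of reproducing one of the rewraps seen below:

# Sketch only: assumes the formatting matches black at its default 88-column limit;
# the commit itself does not state which formatter was used.
import black

# One of the over-long lines touched in this commit (from WQLinear.forward).
src = (
    "out = awq_inference_engine.gemm_forward_cuda("
    "x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)\n"
)

# black.format_str() applies the same rules black applies to files on disk.
print(black.format_str(src, mode=black.Mode(line_length=88)))
# Expected output, matching the wrapped form in the corresponding hunk below:
# out = awq_inference_engine.gemm_forward_cuda(
#     x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
# )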
@@ -712,9 +712,11 @@ class Seq2SeqLM(Model):
                 # Decode all tokens
                 output_text, _, _ = self.decode_token(
                     all_decoder_input_ids,
-                    prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
+                    prefix_offset=len(all_decoder_input_ids)
+                    - decoder_input_length
+                    - 1,
                     read_offset=len(all_decoder_input_ids) - decoder_input_length,
-                    skip_special_tokens=True
+                    skip_special_tokens=True,
                 )

                 # Get seed
...
@@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch

+
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
         self.cache = cache
@@ -26,7 +27,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         # Force inference mode for the lifetime of TextGenerationService
         self._inference_mode_raii_guard = torch._C._InferenceMode(True)

-
     async def Info(self, request, context):
         return self.model.info

@@ -55,9 +55,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
-        if self.model.batch_type == IdeficsCausalLMBatch: #Hack, i would rather use kwargs in the `from_pb` call
+        if (
+            self.model.batch_type == IdeficsCausalLMBatch
+        ):  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device
+                request.batch,
+                self.model.tokenizer,
+                self.model.processor,
+                self.model.dtype,
+                self.model.device,
             )
         else:
             batch = self.model.batch_type.from_pb(
@@ -70,9 +76,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             )

     async def Prefill(self, request, context):
-        if self.model.batch_type == IdeficsCausalLMBatch: #Hack, i would rather use kwargs in the `from_pb` call
+        if (
+            self.model.batch_type == IdeficsCausalLMBatch
+        ):  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device
+                request.batch,
+                self.model.tokenizer,
+                self.model.processor,
+                self.model.dtype,
+                self.model.device,
             )
         else:
             batch = self.model.batch_type.from_pb(
...
@@ -42,7 +42,9 @@ class WQLinear(nn.Module):

     @torch.no_grad()
     def forward(self, x):
-        out_shape = x.shape[:-1] + (self.out_features, )
-        out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
+        out_shape = x.shape[:-1] + (self.out_features,)
+        out = awq_inference_engine.gemm_forward_cuda(
+            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
+        )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
@@ -578,7 +578,9 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
     return trainloader, valenc


-def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False):
+def get_loaders(
+    name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False
+):
     if "wikitext2" in name:
         return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)
     if "ptb" in name:
@@ -927,7 +929,7 @@ def quantize(
         seed=seed,
         model_id=model_id,
         seqlen=model.seqlen,
-        trust_remote_code=trust_remote_code
+        trust_remote_code=trust_remote_code,
     )

     tick = time.time()
...
@@ -38,6 +38,7 @@ if os.getenv("DISABLE_EXLLAMA") == "True":
 elif CAN_EXLLAMA:
     try:
         from text_generation_server.utils.gptq.exllama import Ex4bitLinear
+
         HAS_EXLLAMA = True
     except ImportError:
         pass
@@ -47,6 +48,7 @@ from typing import Optional
 HAS_EETQ = False
 try:
     from EETQ import quant_weights, w8_a16_gemm
+
     HAS_EETQ = True
 except ImportError:
     pass
@@ -74,12 +76,18 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
     ln.bias = None
     return ln

+
 @classmethod
 def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
     weight = weights.get_tensor(f"{prefix}.weight")
     bias = weights.get_tensor(f"{prefix}.bias")
     with init_empty_weights():
-        conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
+        conv2d = cls(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+        )

     conv2d.weight = nn.Parameter(weight)
     conv2d.bias = nn.Parameter(bias)
@@ -87,10 +95,17 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):


 @classmethod
-def load_conv2d_no_bias(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
+def load_conv2d_no_bias(
+    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
+):
     weight = weights.get_tensor(f"{prefix}.weight")
     with init_empty_weights():
-        conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
+        conv2d = cls(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+        )

     conv2d.weight = nn.Parameter(weight)
     conv2d.bias = None
@@ -215,7 +230,10 @@ class Linear4bit(nn.Module):
     def __init__(self, weight, bias, quant_type):
         super().__init__()
         self.weight = Params4bit(
-            weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
+            weight.data,
+            requires_grad=False,
+            compress_statistics=True,
+            quant_type=quant_type,
         )
         self.compute_dtype = None
         self.weight.cuda(weight.device)
@@ -246,7 +264,10 @@ class Linear4bit(nn.Module):

 @lru_cache(1)
 def warn_deprecate_bnb():
-    logger.warning("Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce")
+    logger.warning(
+        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
+    )

+
 def get_linear(weight, bias, quantize):
     if quantize is None:
@@ -255,7 +276,9 @@ def get_linear(weight, bias, quantize):
         if HAS_EETQ:
             linear = EETQLinear(weight, bias)
         else:
-            raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
+            raise ImportError(
+                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
+            )
     elif quantize == "bitsandbytes":
         warn_deprecate_bnb()
         linear = Linear8bitLt(
@@ -305,7 +328,14 @@ def get_linear(weight, bias, quantize):
             raise NotImplementedError(
                 f"The passed weight is not `awq` compatible, loader needs to be updated."
             )
-        linear = WQLinear(w_bit=bits, group_size=groupsize, qweight=qweight, qzeros=qzeros, scales=scales, bias=bias is not None)
+        linear = WQLinear(
+            w_bit=bits,
+            group_size=groupsize,
+            qweight=qweight,
+            qzeros=qzeros,
+            scales=scales,
+            bias=bias is not None,
+        )
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
@@ -392,9 +422,7 @@ class TensorParallelColumnLinear(SuperLayer):
     @classmethod
     def load_qkv(cls, config, prefix: str, weights, bias: bool):
         """Specific method when the QKV was joined after the fact"""
-        weight = weights.get_weights_col_packed_qkv(
-            prefix, quantize=config.quantize
-        )
+        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
         if bias:
             raise NotImplementedError("packed_qkv only implemented for baichuan")
         else:
@@ -530,14 +558,16 @@ try:

     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
-            base
-            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
         )
         return inv_freq

     def _get_rope_config(config):
         if os.getenv("ROPE_SCALING", None) is not None:
-            rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
+            rope_scaling = {
+                "type": os.environ["ROPE_SCALING"],
+                "factor": float(os.environ["ROPE_FACTOR"]),
+            }
             return rope_scaling
         return getattr(config, "rope_scaling", None)
@@ -563,9 +593,17 @@ try:
                 if rope_scaling["type"] == "linear":
                     pass
                 elif rope_scaling["type"] == "dynamic":
-                    return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
+                    return DynamicPositionRotaryEmbedding(
+                        dim=dim,
+                        max_position_embeddings=config.max_position_embeddings,
+                        base=base,
+                        device=inv_freq.device,
+                        scaling_factor=scaling_factor,
+                    )
                 else:
-                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+                    raise NotImplementedError(
+                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+                    )
             return cls(inv_freq, scaling_factor)

         @classmethod
@@ -583,9 +621,17 @@ try:
                 if rope_scaling["type"] == "linear":
                     pass
                 elif rope_scaling["type"] == "dynamic":
-                    return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
+                    return DynamicPositionRotaryEmbedding(
+                        dim=2 * inv_freq.shape[0],
+                        max_position_embeddings=config.max_position_embeddings,
+                        base=10000.0,
+                        device=inv_freq.device,
+                        scaling_factor=scaling_factor,
+                    )
                 else:
-                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+                    raise NotImplementedError(
+                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+                    )
             return cls(inv_freq, scaling_factor)

         def _update_cos_sin_cache(self, dtype, device, seqlen):
@@ -645,8 +691,13 @@ try:
                 or self._cos_cached.dtype != dtype
             ):
                 if seqlen > self.max_position_embeddings:
-                    newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
-                    self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
+                    newbase = self.base * (
+                        (self.scaling_factor * seqlen / self.max_position_embeddings)
+                        - (self.scaling_factor - 1)
+                    ) ** (self.dim / (self.dim - 2))
+                    self.inv_freq = _create_inv_freq(
+                        self.dim, newbase, self.inv_freq.device
+                    )
                 self._seq_len_cached = seqlen
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                 # Don't do einsum, it converts fp32 to fp16
@@ -656,6 +707,5 @@ try:
                 self._cos_cached = torch.cos(freqs).to(dtype)
                 self._sin_cached = torch.sin(freqs).to(dtype)

-
 except ImportError:
     pass
@@ -6,6 +6,7 @@ import torch
 from transformers import AutoTokenizer
 from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM

+
 def download_and_unload_peft(model_id, revision, trust_remote_code):
     torch_dtype = torch.float16
@@ -41,6 +42,3 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
     model.save_pretrained(cache_dir, safe_serialization=True)
     model.config.save_pretrained(cache_dir)
     tokenizer.save_pretrained(cache_dir)
-
-
-
@@ -62,7 +62,7 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()

-    def get_tensor(self, tensor_name: str, to_device = True):
+    def get_tensor(self, tensor_name: str, to_device=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
@@ -110,7 +110,6 @@ class Weights:
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)

-
     def _get_qweight(self, name: str):
         slice_ = self._get_slice(name)
         total_size = slice_.get_shape()[1]
@@ -119,14 +118,16 @@ class Weights:
         world_size = self.process_group.size()
         rank = self.process_group.rank()

-        assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
+        assert (
+            single_size % world_size == 0
+        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
         block_size = single_size // world_size
         start = rank * block_size
         stop = (rank + 1) * block_size

         q = slice_[:, start:stop]
-        k = slice_[:, start+single_size:stop+single_size]
-        v = slice_[:, start+2*single_size:stop+2*single_size]
-        weight = torch.cat([q,k,v], dim=1)
+        k = slice_[:, start + single_size : stop + single_size]
+        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
+        weight = torch.cat([q, k, v], dim=1)
         weight = weight.to(device=self.device)
         return weight
@@ -161,14 +162,16 @@ class Weights:
         world_size = self.process_group.size()
         rank = self.process_group.rank()

-        assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards"
+        assert (
+            single_size % world_size == 0
+        ), f"Prepacked qkv cannot be sharded across {world_size} shards"
         block_size = single_size // world_size
         start = rank * block_size
         stop = (rank + 1) * block_size
         q = slice_[start:stop]
-        k = slice_[start+single_size:stop+single_size]
-        v = slice_[start+2*single_size:stop+2*single_size]
-        weight = torch.cat([q,k,v], dim=0)
+        k = slice_[start + single_size : stop + single_size]
+        v = slice_[start + 2 * single_size : stop + 2 * single_size]
+        weight = torch.cat([q, k, v], dim=0)
         weight = weight.to(device=self.device)
         weight = weight.to(dtype=self.dtype)
         return weight
...