Unverified Commit 47954b81 authored by OlivierDehaene, committed by GitHub

feat: format code (#1070)

parent b32e9ce9
@@ -712,9 +712,11 @@ class Seq2SeqLM(Model):
# Decode all tokens
output_text, _, _ = self.decode_token(
all_decoder_input_ids,
prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
prefix_offset=len(all_decoder_input_ids)
- decoder_input_length
- 1,
read_offset=len(all_decoder_input_ids) - decoder_input_length,
skip_special_tokens=True
skip_special_tokens=True,
)
# Get seed
@@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
self.cache = cache
@@ -26,7 +27,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
# Force inference mode for the lifetime of TextGenerationService
self._inference_mode_raii_guard = torch._C._InferenceMode(True)
async def Info(self, request, context):
return self.model.info
@@ -55,9 +55,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
if self.model.batch_type == IdeficsCausalLMBatch: #Hack, i would rather use kwargs in the `from_pb` call
if (
self.model.batch_type == IdeficsCausalLMBatch
): # Hack, i would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device
request.batch,
self.model.tokenizer,
self.model.processor,
self.model.dtype,
self.model.device,
)
else:
batch = self.model.batch_type.from_pb(
@@ -70,9 +76,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
)
async def Prefill(self, request, context):
if self.model.batch_type == IdeficsCausalLMBatch: #Hack, i would rather use kwargs in the `from_pb` call
if (
self.model.batch_type == IdeficsCausalLMBatch
): # Hack, i would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device
request.batch,
self.model.tokenizer,
self.model.processor,
self.model.dtype,
self.model.device,
)
else:
batch = self.model.batch_type.from_pb(
@@ -11,7 +11,7 @@ import awq_inference_engine # with CUDA kernels
# super().__init__()
# self.act = module
# self.scales = nn.Parameter(scales.data)
#
#
# def forward(self, x):
# return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
@@ -19,10 +19,10 @@ import awq_inference_engine # with CUDA kernels
class WQLinear(nn.Module):
def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit
@@ -42,7 +42,9 @@ class WQLinear(nn.Module):
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features, )
out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
out_shape = x.shape[:-1] + (self.out_features,)
out = awq_inference_engine.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
@@ -578,7 +578,9 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
return trainloader, valenc
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False):
def get_loaders(
name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False
):
if "wikitext2" in name:
return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)
if "ptb" in name:
@@ -927,7 +929,7 @@ def quantize(
seed=seed,
model_id=model_id,
seqlen=model.seqlen,
trust_remote_code=trust_remote_code
trust_remote_code=trust_remote_code,
)
tick = time.time()
@@ -22,7 +22,7 @@ from text_generation_server.utils.gptq.quant_linear import QuantLinear
HAS_AWQ = True
try:
try:
from text_generation_server.utils.awq.quantize.qmodule import WQLinear
except ImportError:
HAS_AWQ = False
@@ -36,17 +36,19 @@ CAN_EXLLAMA = major >= 8
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
elif CAN_EXLLAMA:
try:
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
HAS_EXLLAMA = True
except ImportError:
pass
try:
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
HAS_EXLLAMA = True
except ImportError:
pass
from typing import Optional
HAS_EETQ = False
try:
from EETQ import quant_weights, w8_a16_gemm
HAS_EETQ = True
except ImportError:
pass
@@ -74,12 +76,18 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
ln.bias = None
return ln
@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
weight = weights.get_tensor(f"{prefix}.weight")
bias = weights.get_tensor(f"{prefix}.bias")
with init_empty_weights():
conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
conv2d = cls(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
)
conv2d.weight = nn.Parameter(weight)
conv2d.bias = nn.Parameter(bias)
@@ -87,10 +95,17 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, st
@classmethod
def load_conv2d_no_bias(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
def load_conv2d_no_bias(
cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
weight = weights.get_tensor(f"{prefix}.weight")
with init_empty_weights():
conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
conv2d = cls(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
)
conv2d.weight = nn.Parameter(weight)
conv2d.bias = None
@@ -215,7 +230,10 @@ class Linear4bit(nn.Module):
def __init__(self, weight, bias, quant_type):
super().__init__()
self.weight = Params4bit(
weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
weight.data,
requires_grad=False,
compress_statistics=True,
quant_type=quant_type,
)
self.compute_dtype = None
self.weight.cuda(weight.device)
@@ -246,7 +264,10 @@ class Linear4bit(nn.Module):
@lru_cache(1)
def warn_deprecate_bnb():
logger.warning("Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce")
logger.warning(
"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
)
def get_linear(weight, bias, quantize):
if quantize is None:
@@ -255,7 +276,9 @@ def get_linear(weight, bias, quantize):
if HAS_EETQ:
linear = EETQLinear(weight, bias)
else:
raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
raise ImportError(
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
elif quantize == "bitsandbytes":
warn_deprecate_bnb()
linear = Linear8bitLt(
@@ -305,7 +328,14 @@ def get_linear(weight, bias, quantize):
raise NotImplementedError(
f"The passed weight is not `awq` compatible, loader needs to be updated."
)
linear = WQLinear(w_bit=bits, group_size=groupsize, qweight=qweight, qzeros=qzeros, scales=scales, bias=bias is not None)
linear = WQLinear(
w_bit=bits,
group_size=groupsize,
qweight=qweight,
qzeros=qzeros,
scales=scales,
bias=bias is not None,
)
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear
@@ -392,9 +422,7 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load_qkv(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_qkv(
prefix, quantize=config.quantize
)
weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
if bias:
raise NotImplementedError("packed_qkv only implemented for baichuan")
else:
@@ -530,14 +558,16 @@ try:
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return inv_freq
def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
rope_scaling = {
"type": os.environ["ROPE_SCALING"],
"factor": float(os.environ["ROPE_FACTOR"]),
}
return rope_scaling
return getattr(config, "rope_scaling", None)
@@ -563,9 +593,17 @@ try:
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
return DynamicPositionRotaryEmbedding(
dim=dim,
max_position_embeddings=config.max_position_embeddings,
base=base,
device=inv_freq.device,
scaling_factor=scaling_factor,
)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
)
return cls(inv_freq, scaling_factor)
@classmethod
@@ -583,9 +621,17 @@ try:
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
return DynamicPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0],
max_position_embeddings=config.max_position_embeddings,
base=10000.0,
device=inv_freq.device,
scaling_factor=scaling_factor,
)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
)
return cls(inv_freq, scaling_factor)
def _update_cos_sin_cache(self, dtype, device, seqlen):
@@ -645,8 +691,13 @@ try:
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings:
newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
newbase = self.base * (
(self.scaling_factor * seqlen / self.max_position_embeddings)
- (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq(
self.dim, newbase, self.inv_freq.device
)
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
@@ -656,6 +707,5 @@ try:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
except ImportError:
pass
@@ -6,6 +6,7 @@ import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
def download_and_unload_peft(model_id, revision, trust_remote_code):
torch_dtype = torch.float16
@@ -33,7 +34,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
base_model_id = model.peft_config["default"].base_model_name_or_path
model = model.merge_and_unload()
os.makedirs(model_id, exist_ok=True)
cache_dir = model_id
logger.info(f"Saving the newly created merged model to {cache_dir}")
@@ -41,6 +42,3 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
model.save_pretrained(cache_dir, safe_serialization=True)
model.config.save_pretrained(cache_dir)
tokenizer.save_pretrained(cache_dir)
@@ -363,7 +363,7 @@ def batch_top_tokens(
# Find the new "fuzzy" top n values
top_n_indices = (logprobs >= nth_highest).nonzero()
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
# Take a new topk for these new max n values
top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
@@ -62,7 +62,7 @@ class Weights:
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str, to_device = True):
def get_tensor(self, tensor_name: str, to_device=True):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
@@ -110,7 +110,6 @@ class Weights:
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)
def _get_qweight(self, name: str):
slice_ = self._get_slice(name)
total_size = slice_.get_shape()[1]
@@ -119,14 +118,16 @@ class Weights:
world_size = self.process_group.size()
rank = self.process_group.rank()
assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
assert (
single_size % world_size == 0
), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q = slice_[:, start:stop]
k = slice_[:, start+single_size:stop+single_size]
v = slice_[:, start+2*single_size:stop+2*single_size]
weight = torch.cat([q,k,v], dim=1)
k = slice_[:, start + single_size : stop + single_size]
v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
weight = torch.cat([q, k, v], dim=1)
weight = weight.to(device=self.device)
return weight
@@ -137,14 +138,14 @@ class Weights:
"""
if quantize in ["gptq", "awq"]:
try:
qweight = self._get_qweight(f"{prefix}.qweight")
qweight = self._get_qweight(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
)
qzeros = self._get_qweight(f"{prefix}.qzeros")
scales = self._get_qweight(f"{prefix}.scales")
qzeros = self._get_qweight(f"{prefix}.qzeros")
scales = self._get_qweight(f"{prefix}.scales")
scales = scales.to(dtype=self.dtype)
if quantize == "gptq":
g_idx = self.get_tensor(f"{prefix}.g_idx")
@@ -154,21 +155,23 @@ class Weights:
bits, groupsize = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
slice_ = self._get_slice(f"{prefix}.weight")
slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0]
assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
single_size = total_size // 3
world_size = self.process_group.size()
rank = self.process_group.rank()
assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards"
assert (
single_size % world_size == 0
), f"Prepacked qkv cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q = slice_[start:stop]
k = slice_[start+single_size:stop+single_size]
v = slice_[start+2*single_size:stop+2*single_size]
weight = torch.cat([q,k,v], dim=0)
k = slice_[start + single_size : stop + single_size]
v = slice_[start + 2 * single_size : stop + 2 * single_size]
weight = torch.cat([q, k, v], dim=0)
weight = weight.to(device=self.device)
weight = weight.to(dtype=self.dtype)
return weight
@@ -205,7 +208,7 @@ class Weights:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
return weight
def get_tensor_shard(self, var, dim):
world_size = self.process_group.size()
rank = self.process_group.rank()
@@ -220,7 +223,7 @@ class Weights:
raise NotImplementedError("Let's make that generic when needed")
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
return tensor
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
@@ -303,7 +306,7 @@ class Weights:
scales = self.get_sharded(f"{prefix}.scales", dim=0)
g_idx = None
use_exllama = False
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)