Commit 72ee382d authored by OlivierDehaene

chore: formatting

parent 3a521c92
@@ -3,17 +3,20 @@ import torch
from typing import Optional
from text_generation_server.models.flash_mistral import BaseFlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
MixtralConfig,
FlashMixtralForCausalLM,
)
class FlashMixtral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMixtral, self).__init__(
config_cls=MixtralConfig,
@@ -22,5 +25,5 @@ class FlashMixtral(BaseFlashMistral):
revision=revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code
trust_remote_code=trust_remote_code,
)
@@ -792,7 +792,10 @@ class IdeficsCausalLM(Model):
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
prefill_token_ids,
prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
@@ -803,10 +806,10 @@ class IdeficsCausalLM(Model):
request.id,
prefill_tokens,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
......
@@ -56,7 +56,7 @@ class Model(ABC):
dtype=str(self.dtype),
device_type=self.device.type,
window_size=self.sliding_window,
speculate=self.speculate
speculate=self.speculate,
)
@property
......
@@ -736,7 +736,7 @@ class Seq2SeqLM(Model):
[self.tokenizer.bos_token_id],
[float("nan")],
[self.tokenizer.bos_token],
[False]
[False],
)
else:
prefill_tokens = None
@@ -763,10 +763,10 @@ class Seq2SeqLM(Model):
request.id,
prefill_tokens,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
......
@@ -66,7 +66,10 @@ class Tokens:
def to_pb(self) -> generate_pb2.Tokens:
return generate_pb2.Tokens(
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special
ids=self.token_ids,
logprobs=self.logprobs,
texts=self.texts,
is_special=self.is_special,
)
def __len__(self):
......
@@ -159,7 +159,13 @@ def serve(
try:
model = get_model(
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
model_id,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
)
except Exception:
logger.exception("Error when initializing model")
@@ -207,5 +213,7 @@ def serve(
await server.stop(0)
asyncio.run(
serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code)
serve_inner(
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
)
)
@@ -51,7 +51,9 @@ except ImportError as e:
) from e
elif IS_ROCM_SYSTEM:
for idx in range(torch.cuda.device_count()):
if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx):
if "MI210" not in torch.cuda.get_device_name(
idx
) and "MI250" not in torch.cuda.get_device_name(idx):
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)
@@ -91,8 +93,10 @@ def attention(
)
elif HAS_FLASH_ATTN_V2_ROCM:
if window_size_left != -1:
raise ValueError(f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left}).")
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
# RoCm flash API does not take the window_size_left and window_size_right arguments.
return flash_attn_2_cuda.varlen_fwd(
q,
......
@@ -11,40 +11,44 @@ logger = getLogger(__name__)
try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
except ImportError:
logger.error('exllamav2_kernels not installed.')
logger.error("exllamav2_kernels not installed.")
raise
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
"""Matrix multiplication, returns x @ q4"""
output_shape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device)
output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device)
gemm_half_q_half(x, q_handle, output, force_cuda)
return output.view(output_shape)
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
"""
Create Q matrix
Create Q matrix
"""
# EXL2
# won't work at the moment because the tensors are not the same.
# won't work at the moment because the tensors are not the same.
if "q_weight" in w:
w["q_scale_max"] /= 256
w["q_perm"] = w["q_perm"].short()
w["q_invperm"] = w["q_invperm"].short()
return make_q_matrix(w["q_weight"],
w["q_perm"],
w["q_invperm"],
w["q_scale"],
w["q_scale_max"],
w["q_groups"],
none_tensor,
none_tensor,
none_tensor,
temp_dq)
return make_q_matrix(
w["q_weight"],
w["q_perm"],
w["q_invperm"],
w["q_scale"],
w["q_scale_max"],
w["q_groups"],
none_tensor,
none_tensor,
none_tensor,
temp_dq,
)
# GPTQ
elif "qweight" in w:
if w["scales"].dtype == torch.float:
@@ -52,31 +56,40 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
# GPTQ with g_idx (act_order)
if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device)
w["q_perm"] = torch.empty(
(w["qweight"].shape[0] * 8,),
dtype=torch.short,
device=w["qweight"].device,
)
w["q_invperm"] = torch.empty_like(w["q_perm"])
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
return make_q_matrix(w["qweight"],
w["q_perm"],
w["q_invperm"],
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
w["g_idx"].cpu(),
temp_dq)
return make_q_matrix(
w["qweight"],
w["q_perm"],
w["q_invperm"],
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
w["g_idx"].cpu(),
temp_dq,
)
# GPTQ without g_idx
else:
return make_q_matrix(w["qweight"],
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
none_tensor,
temp_dq)
return make_q_matrix(
w["qweight"],
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
none_tensor,
temp_dq,
)
DEVICE = None
FIXED_BYTES = 0
@@ -106,14 +119,15 @@ class QuantLinear(nn.Module):
super().__init__()
if bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization."
)
self.q_handle = None
self.q_tensors = None
self.bits = bits
self.maxq = 2 ** self.bits - 1
self.maxq = 2**self.bits - 1
self.infeatures = qweight.shape[0] // self.bits * 32
self.outfeatures = qweight.shape[1]
self.padding = - self.outfeatures % 32
self.padding = -self.outfeatures % 32
self.outfeatures = self.outfeatures + self.padding
self.device = qweight.device
@@ -128,9 +142,12 @@ class QuantLinear(nn.Module):
outfeatures = self.outfeatures
assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
assert infeatures % self.group_size == 0
assert qzeros.shape == (infeatures // self.group_size, outfeatures // 32 * self.bits)
assert qzeros.shape == (
infeatures // self.group_size,
outfeatures // 32 * self.bits,
)
assert scales.shape == (infeatures // self.group_size, outfeatures)
assert g_idx.shape == (infeatures, ), f"{g_idx.shape}, {infeatures}"
assert g_idx.shape == (infeatures,), f"{g_idx.shape}, {infeatures}"
global FIXED_BYTES, LAYERS
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
@@ -140,33 +157,31 @@ class QuantLinear(nn.Module):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.q_tensors = {
"qweight":self.qweight,
"qzeros":self.qzeros,
"scales":self.scales,
"g_idx":self.g_idx
"qweight": self.qweight,
"qzeros": self.qzeros,
"scales": self.scales,
"g_idx": self.g_idx,
}
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
self.q_handle = ext_make_q_matrix(
self.q_tensors, temp_dq
)
def forward(self, x, force_cuda = False):
self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
def forward(self, x, force_cuda=False):
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
if self.bias is not None:
output.add_(self.bias)
return output
def temp_dq_size(self):
return self.infeatures * self.outfeatures * 2 + 128
def temp_fwd_size(self, max_input_len, max_batch_size):
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16):
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
class ExLlamaV2DeviceTensors:
device_idx: int
@@ -177,13 +192,16 @@ class ExLlamaV2DeviceTensors:
def __init__(self, device, scratch_bytes):
self.device = device
self.scratch_bytes = scratch_bytes
def prepare(self):
self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device)
self.scratch = torch.empty(
(self.scratch_bytes // 2,), dtype=torch.half, device=self.device
)
def get_scratch_slice(self, size_bytes):
if self.scratch is None: self.prepare()
if self.scratch is None:
self.prepare()
size_bytes = ((size_bytes + 127) // 128) * 128
size_half = size_bytes // 2
......
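Note for readers skimming the scratch-space arithmetic in the QuantLinear methods above: here is a minimal standalone sketch of the same sizing formulas evaluated for a hypothetical 4096x4096 4-bit layer; the layer shape and the printed MiB figure are illustrative assumptions, not values from this commit.

# Hedged sketch of the scratch sizing mirrored from QuantLinear above; the
# 4096x4096 layer shape and the defaults are hypothetical examples.
def temp_dq_size(infeatures: int, outfeatures: int) -> int:
    # Space to dequantize one weight matrix at 2 bytes per element, plus a small margin.
    return infeatures * outfeatures * 2 + 128

def temp_fwd_size(outfeatures: int, max_input_len: int, max_batch_size: int) -> int:
    # Forward-pass workspace at 4 bytes per output element, plus a small margin.
    return outfeatures * max_input_len * max_batch_size * 4 + 128

def scratch_space_fixed(infeatures: int, outfeatures: int,
                        max_input_len: int = 4096, max_batch_size: int = 16) -> int:
    return temp_dq_size(infeatures, outfeatures) + temp_fwd_size(
        outfeatures, max_input_len, max_batch_size
    )

print(f"{scratch_space_fixed(4096, 4096) / 2**20:.1f} MiB for a 4096x4096 layer")

FIXED_BYTES above then keeps the maximum of scratch_space_fixed over the layers it sees, which suggests a single scratch buffer sized for the largest layer is shared between them.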
@@ -35,7 +35,9 @@ HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
logger.warning("Disabling exllama v2 and using v1 instead because there are issues when sharding")
logger.warning(
"Disabling exllama v2 and using v1 instead because there are issues when sharding"
)
V2 = False
if os.getenv("DISABLE_EXLLAMA") == "True":
@@ -43,17 +45,19 @@ if os.getenv("DISABLE_EXLLAMA") == "True":
elif CAN_EXLLAMA:
try:
if V2:
from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear,
create_exllama_buffers,
set_device,
)
from text_generation_server.utils.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear,
create_exllama_buffers,
set_device,
)
HAS_EXLLAMA = "2"
else:
from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear,
create_exllama_buffers,
set_device,
)
from text_generation_server.utils.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear,
create_exllama_buffers,
set_device,
)
HAS_EXLLAMA = "1"
@@ -114,7 +118,7 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, st
@classmethod
def load_conv2d_no_bias(
cls, prefix, weights, in_channels, out_channels, kernel_size, stride
cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
weight = weights.get_tensor(f"{prefix}.weight")
with init_empty_weights():
@@ -138,9 +142,9 @@ torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
class FastLinear(nn.Module):
def __init__(
self,
weight,
bias,
self,
weight,
bias,
) -> None:
super().__init__()
self.weight = nn.Parameter(weight)
@@ -164,9 +168,9 @@ class FastLinear(nn.Module):
class EETQLinear(nn.Module):
def __init__(
self,
weight,
bias,
self,
weight,
bias,
) -> None:
super().__init__()
device = weight.device
@@ -185,13 +189,13 @@ class EETQLinear(nn.Module):
class Linear8bitLt(nn.Module):
def __init__(
self,
weight,
bias,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
self,
weight,
bias,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
):
super().__init__()
assert (
@@ -325,7 +329,9 @@ def get_linear(weight, bias, quantize):
)
if use_exllama:
linear = ExllamaQuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
linear = ExllamaQuantLinear(
qweight, qzeros, scales, g_idx, bias, bits, groupsize
)
else:
linear = QuantLinear(
qweight,
@@ -533,7 +539,6 @@ try:
else:
dropout_layer_norm = None
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
@@ -569,7 +574,6 @@ try:
return normed_hidden_states, residual
class FastRMSNorm(nn.Module):
def __init__(self, weight: torch.Tensor, eps: float):
super().__init__()
@@ -601,7 +605,11 @@ try:
return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
(
normed_hidden_states,
res,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
@@ -638,7 +646,8 @@ try:
return out, residual
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
except ImportError:
pass
@@ -650,14 +659,12 @@ try:
elif IS_ROCM_SYSTEM:
from vllm import pos_encoding_ops
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return inv_freq
def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {
@@ -667,7 +674,6 @@ try:
return rope_scaling
return getattr(config, "rope_scaling", None)
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq, scaling_factor):
super().__init__()
@@ -680,17 +686,23 @@ try:
self.scaling_factor = scaling_factor
self.dynamic_args = None
def forward(self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
# Such controlflows may add some overhead.
if IS_CUDA_SYSTEM:
rotary_dim = cos.shape[-1]
q1 = query[..., :rotary_dim]
q2 = query[..., rotary_dim: 2 * rotary_dim]
q2 = query[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., :rotary_dim]
k2 = key[..., rotary_dim: 2 * rotary_dim]
k2 = key[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif IS_ROCM_SYSTEM:
@@ -700,17 +712,11 @@ try:
head_size = query.shape[-1]
# Inplace operation, updating query and key.
pos_encoding_ops.rotary_embedding(
query,
key,
head_size,
cos,
sin,
True
)
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
@classmethod
def static(cls, config, dim, base, device):
@@ -732,15 +738,16 @@ try:
elif rope_scaling["type"] == "yarn":
return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling["original_max_position_embeddings"],
max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
base=10000.0,
device=inv_freq.device,
scaling_factor=scaling_factor,
extrapolation_factor=1,
attn_factor=1,
beta_fast=32,
beta_slow=1
beta_slow=1,
)
else:
raise NotImplementedError(
@@ -773,15 +780,16 @@ try:
elif rope_scaling["type"] == "yarn":
return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling["original_max_position_embeddings"],
max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
base=10000.0,
device=inv_freq.device,
scaling_factor=scaling_factor,
extrapolation_factor=1,
attn_factor=1,
beta_fast=32,
beta_slow=1
beta_slow=1,
)
else:
raise NotImplementedError(
@@ -793,9 +801,9 @@ try:
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
@@ -809,7 +817,7 @@ try:
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
):
"""
Return cos and sin for the asked position ids
@@ -827,7 +835,6 @@ try:
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
return cos.unsqueeze(1), sin.unsqueeze(1)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device)
@@ -840,14 +847,14 @@ try:
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings:
newbase = self.base * (
(self.scaling_factor * seqlen / self.max_position_embeddings)
- (self.scaling_factor - 1)
(self.scaling_factor * seqlen / self.max_position_embeddings)
- (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq(
self.dim, newbase, self.inv_freq.device
@@ -861,24 +868,28 @@ try:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
# Inverse dim formula to find dim based on number of rotations
import math
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
def find_correction_dim(
num_rotations, dim, base=10000, max_position_embeddings=2048
):
return (
dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
) / (2 * math.log(base))
# Find dim range bounds based on rotations
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
low = math.floor(find_correction_dim(
low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(
high_rot, dim, base, max_position_embeddings))
def find_correction_range(
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
low = math.floor(
find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def linear_ramp_mask(min, max, dim):
if min == max:
max += 0.001 # Prevent singularity
@@ -887,16 +898,25 @@ try:
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def get_mscale(scale=1):
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *, extrapolation_factor,
attn_factor, beta_fast, beta_slow):
def __init__(
self,
dim,
max_position_embeddings,
base,
device,
scaling_factor,
*,
extrapolation_factor,
attn_factor,
beta_fast,
beta_slow,
):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
@@ -906,16 +926,17 @@ try:
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.mscale = float(get_mscale(
self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
self.mscale = float(
get_mscale(self.scaling_factor) * self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings:
inv_freq_extrapolation = _create_inv_freq(
@@ -923,15 +944,26 @@ try:
)
freqs = 1.0 / inv_freq_extrapolation
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base,
self.max_position_embeddings)
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(
device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
low, high = find_correction_range(
self.beta_fast,
self.beta_slow,
self.dim,
self.base,
self.max_position_embeddings,
)
inv_freq_mask = (
1
- linear_ramp_mask(low, high, self.dim // 2).float().to(device)
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
self.inv_freq = inv_freq
self.mscale = float(get_mscale(
self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
self.mscale = float(
get_mscale(self.scaling_factor) * self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
......
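As a reading aid for the YaRN helpers reformatted above (find_correction_dim, find_correction_range, linear_ramp_mask, and the interpolation/extrapolation blend in YarnPositionRotaryEmbedding._update_cos_sin_cache), here is a self-contained sketch that computes the blended inverse frequencies once. The ramp body and the defaults (dim=128, scaling_factor=2.0, and so on) are assumptions for illustration; only the overall structure mirrors the code in the diff.

import math
import torch

# Hedged sketch of the YaRN inverse-frequency blend; parameter defaults are
# illustrative, and the ramp body is reconstructed from the clamp/singularity
# guard visible above.
def yarn_inv_freq(dim=128, base=10000.0, scaling_factor=2.0,
                  max_position_embeddings=2048, beta_fast=32, beta_slow=1,
                  extrapolation_factor=1.0, device="cpu"):
    # Plain RoPE frequencies (extrapolation) and their scaled counterpart (interpolation).
    inv_freq_extrapolation = 1.0 / (
        base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
    )
    inv_freq_interpolation = inv_freq_extrapolation / scaling_factor

    def correction_dim(num_rotations):
        # Inverse dim formula: which dim index completes num_rotations over the context.
        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
            2 * math.log(base)
        )

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), dim - 1)
    if low == high:
        high += 0.001  # Prevent singularity, as in linear_ramp_mask above.

    # Linear ramp over the half-dim frequency axis, clamped to [0, 1].
    ramp = torch.clamp(
        (torch.arange(dim // 2, dtype=torch.float32, device=device) - low) / (high - low),
        0, 1,
    )
    inv_freq_mask = (1 - ramp) * extrapolation_factor

    # High-frequency dims keep the extrapolated frequencies, low-frequency dims
    # fall back to the interpolated (position-scaled) ones.
    return inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask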
@@ -2,6 +2,7 @@ import torch
from dataclasses import dataclass
from text_generation_server.utils.layers import TensorParallelHead, FastLinear
@dataclass
class Output:
logits: torch.FloatTensor = None
@@ -11,7 +12,9 @@ class Output:
class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
self.linear = FastLinear.load(
config, prefix=f"{prefix}.linear", weights=weights, bias=True
)
self.act = torch.nn.SiLU()
def forward(self, x):
@@ -19,15 +22,13 @@ class ResBlock(torch.nn.Module):
class MedusaModel(torch.nn.Module):
def __init__(
self,
config,
weights,
lm_head
):
def __init__(self, config, weights, lm_head):
super().__init__()
self.heads = torch.nn.ModuleList(
[MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
[
MedusaHead(config, prefix=f"{i}", weights=weights)
for i in range(config["medusa_num_heads"])
]
)
self.lm_head = lm_head
@@ -40,9 +41,16 @@ class MedusaModel(torch.nn.Module):
class MedusaHead(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.blocks = torch.nn.ModuleList([ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])])
self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
for i in range(config["medusa_num_layers"])
]
)
n = len(self.blocks)
self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)
self.out = FastLinear.load(
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
)
def forward(self, x):
for block in self.blocks:
......
@@ -7,23 +7,26 @@ from vllm import attention_ops
_PARTITION_SIZE = 512
def reshape_and_cache(key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor,
slots: torch.Tensor):
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
def attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@@ -45,9 +48,7 @@ def attention(
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (
(max_s + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
......
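The reflowed max_num_partitions expression above is a ceiling division of the longest sequence by the 512-token partition size. A tiny sketch of that arithmetic, with made-up sequence lengths (the second half of the V1/V2 heuristic is truncated in the hunk and is not reproduced here):

_PARTITION_SIZE = 512  # same constant as in the hunk above

def max_num_partitions(max_s: int) -> int:
    # Ceiling division: how many 512-token partitions cover the longest sequence.
    return (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

# Illustrative sequence lengths only.
for max_s in (200, 512, 513, 4096):
    parts = max_num_partitions(max_s)
    # Per the NOTE above, a single partition needs no cross-partition
    # reduction, so the V1 kernel is preferred in that case.
    print(max_s, parts, "V1 candidate" if parts == 1 else "needs reduction")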
@@ -38,7 +38,9 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
os.makedirs(model_id, exist_ok=True)
cache_dir = model_id
logger.info(f"Saving the newly created merged model to {cache_dir}")
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(
base_model_id, trust_remote_code=trust_remote_code
)
model.save_pretrained(cache_dir, safe_serialization=True)
model.config.save_pretrained(cache_dir)
tokenizer.save_pretrained(cache_dir)
SPECULATE = None
def get_speculate() -> int:
global SPECULATE
return SPECULATE
def set_speculate(speculate: int):
global SPECULATE
SPECULATE = speculate
@@ -16,6 +16,7 @@ from text_generation_server.utils.logits_process import (
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
class NextTokenChooser:
def __init__(
self,
@@ -145,21 +146,31 @@ class StoppingCriteria:
pb.ignore_eos_token,
)
def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
def create_n_gram_speculation(
input_ids: torch.Tensor,
next_ids: torch.Tensor,
accepted_ids: torch.Tensor,
speculate: int,
verbose: bool,
):
# Very trivial approach, find first match in the string.
# This is much less refined than actual n-gram but seems to work
# relatively OK in grounded mode and is by far much faster with
# much less worst case complexity as everything happens on device.
B = accepted_ids.shape[0]
device = input_ids.device
seeds = next_ids[accepted_ids.cumsum(dim=-1) -1 ]
seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
speculate, device=device
)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
speculative_ids = input_ids.gather(dim=-1, index=all_indices)
return speculative_ids
class HeterogeneousNextTokenChooser:
def __init__(
self,
@@ -228,7 +239,15 @@ class HeterogeneousNextTokenChooser:
self.dtype = dtype
self.device = device
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
def __call__(
self,
input_ids: torch.Tensor,
scores: torch.Tensor,
speculate: int,
speculated_ids: Optional[torch.Tensor] = None,
speculative_scores: Optional[torch.Tensor] = None,
verbose=False,
):
if speculated_ids is not None:
B = scores.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
@@ -249,12 +268,11 @@ class HeterogeneousNextTokenChooser:
for warper in self.warpers:
_scores = warper(input_ids, _scores)
_next_ids = self.choice(_scores)
scores[:, j] = _scores
next_ids[:, j] = _next_ids
next_ids = next_ids.view(B*S)
scores = scores.view( B* S, -1)
next_ids = next_ids.view(B * S)
scores = scores.view(B * S, -1)
if speculated_ids is not None:
accepted_ids = []
@@ -262,7 +280,7 @@ class HeterogeneousNextTokenChooser:
S = speculated_ids.shape[1] + 1
indices = []
for i in range(B):
_next_ids = next_ids[i*S: (i + 1)*S]
_next_ids = next_ids[i * S : (i + 1) * S]
_speculated_ids = speculated_ids[i]
validate_speculative = _next_ids[:-1] == _speculated_ids
index = i * S
@@ -278,7 +296,9 @@ class HeterogeneousNextTokenChooser:
break
accepted_ids.append(accepted)
accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
accepted_ids = torch.tensor(
accepted_ids, device=input_ids.device, dtype=input_ids.dtype
)
next_ids = next_ids[indices]
scores = scores[indices]
indices = torch.arange(B, device=input_ids.device) * S
@@ -296,7 +316,9 @@ class HeterogeneousNextTokenChooser:
speculative_ids = Greedy()(speculative_scores)
else:
# n-gram
speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
speculative_ids = create_n_gram_speculation(
input_ids, next_ids, accepted_ids, speculate, verbose
)
else:
speculative_ids = None
......
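To make the tensor indexing in the reformatted create_n_gram_speculation easier to follow, here is a toy run of the same gather logic on hand-made tensors; the token values, batch size, and speculate=3 are invented for illustration.

import torch

# Toy walk-through of the gather logic in create_n_gram_speculation above.
# All values are illustrative.
input_ids = torch.tensor([[5, 7, 5, 9, 5, 7],
                          [1, 2, 3, 1, 2, 3]])          # (B=2, seq_len=6)
next_ids = torch.tensor([5, 1, 2])                      # flattened accepted tokens
accepted_ids = torch.tensor([1, 2])                     # tokens accepted per sequence
speculate = 3

B = accepted_ids.shape[0]
# The last accepted token of each sequence acts as the n-gram "seed".
seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]       # tensor([5, 2])
# Position just after the first occurrence of the seed in the prompt.
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
speculative_ids = input_ids.gather(dim=-1, index=all_indices)
print(speculative_ids)

Each row of speculative_ids is the window of tokens that followed the seed the first time it appeared in the prompt, which is what gets proposed as the speculative continuation.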
@@ -16,7 +16,7 @@ class Weights:
dtype,
process_group,
aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None
prefix: Optional[str] = None,
):
routing = {}
for filename in filenames:
@@ -213,7 +213,8 @@ class Weights:
bits, groupsize = self._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA
use_exllama = bits==4 and HAS_EXLLAMA and quantize == "gptq"
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -283,7 +284,7 @@ class Weights:
if use_exllama:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim= 0)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
g_idx = g_idx - g_idx[0]
else:
# The triton kernel reorders the scales/zero points instead of the weight/activation.
......
@@ -21,14 +21,14 @@ def main():
block = []
for line in lines:
if line.startswith(" -") or line.startswith(" -"):
rendered_block = '\n'.join(block)
rendered_block = "\n".join(block)
if header:
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
else:
final_doc += f"```shell\n{rendered_block}\n```\n"
block = []
tokens = line.split("<")
if len(tokens)>1:
if len(tokens) > 1:
header = tokens[-1][:-1]
else:
header = line.split("--")[-1]
@@ -36,7 +36,7 @@ def main():
block.append(line)
rendered_block = '\n'.join(block)
rendered_block = "\n".join(block)
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
block = []
......