"tests/vscode:/vscode.git/clone" did not exist on "e57d9e79ca6e75fbcdf76cfd950bdf33e9c9203f"
Unverified commit e3d76564 authored by Nicolas Patry, committed by GitHub

MLPSpeculator. (#1865)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the
      [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and
      [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------
Co-authored-by: Joshua Rosenkranz <joshua.rosenkranz@gmail.com>
parent 3136f27f
@@ -3,11 +3,11 @@ from text_generation_server.layers.tensor_parallel import (
     TensorParallelRowLinear,
     TensorParallelEmbedding,
 )
-from text_generation_server.layers.speculative import SpeculativeHead
 from text_generation_server.layers.linear import (
     get_linear,
     FastLinear,
 )
+from text_generation_server.layers.speculative import SpeculativeHead
 
 # Just to add the `load` methods.
 from text_generation_server.layers.layernorm import load_layer_norm
...
@@ -69,10 +69,13 @@ class MedusaHeadV1(nn.Module):
         from safetensors import safe_open
         import json
 
-        use_medusa = config.use_medusa
+        speculator = config.speculator
 
-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        path = speculator["path"]
+        medusa_config = str(Path(path) / "config.json")
+
+        for fname in speculator["model_paths"]:
+            filename = str(Path(path) / fname)
 
         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
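For orientation, here is a minimal sketch of the `speculator` dict this change introduces (the local path is a hypothetical placeholder; `path` and `model_paths` are the only keys the loaders read):

```python
from pathlib import Path

# Hypothetical local checkout of a Medusa speculator repo.
speculator = {
    "path": Path("/data/medusa-vicuna-7b"),
    "model_paths": ["medusa_lm_head.safetensors"],
}

# MedusaHeadV1.load resolves its config and every weight file against "path".
config_file = str(speculator["path"] / "config.json")
weight_files = [str(speculator["path"] / f) for f in speculator["model_paths"]]
```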
@@ -108,10 +111,10 @@ class MedusaHeadV2(nn.Module):
         from safetensors import safe_open
         import json
 
-        use_medusa = config.use_medusa
+        speculator = config.speculator
 
-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")
 
         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
...
New file (MLPSpeculator implementation):

import torch
import math

from torch import nn
from torch.nn import functional as F

from typing import Optional, Tuple
from text_generation_server.layers import TensorParallelEmbedding, FastLinear
from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.utils.speculate import get_speculate


class MLPSpeculatorLayerNorm(nn.Module):
    """
    A L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    elementwise_scale_weight : torch.Tensor
        learned scaling term after normalization?
    elementwise_shift_bias : torch.Tensor
        learned bias term after normalization?
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
    """

    def __init__(
        self,
        prefix,
        config,
        weights,
        eps=1e-06,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.weight = weights.get_tensor(f"{prefix}.weight")
        self.bias = weights.get_tensor(f"{prefix}.bias")
        self.eps = eps

    def forward(self, x):
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        x = self.weight * x
        x = x + self.bias
        return x


class MLPSpeculatorModel(torch.nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        self.config = config
        self.n_predict = get_speculate()
        self.hidden_size = config.hidden_size

        self.emb = nn.ModuleList(
            [
                TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
                for i in range(self.n_predict)
            ]
        )
        self.proj = [
            FastLinear.load(
                config,
                prefix=f"{prefix}.proj.{i}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_predict)
        ]
        self.head = nn.ModuleList(
            [
                FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
                for i in range(self.n_predict)
            ]
        )
        self.ln = nn.ModuleList(
            [
                MLPSpeculatorLayerNorm(
                    prefix=f"{prefix}.ln.{i}",
                    config=config,
                    weights=weights,
                )
                for i in range(self.n_predict)
            ]
        )

        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
        self.state_weight = 0.5 ** (0.5 / self.n_predict)
        self.emb_weight = math.sqrt(1 - self.state_weight**2)
        self.activation = nn.GELU()
        # TODO
        self.vsize = config.vocab_size
        self.inner_dim = config.speculator_config["inner_dim"]
        self.top_k_tokens_per_head = [1] * self.n_predict

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
    ):
        top_k_tokens_per_head = self.top_k_tokens_per_head

        # k indicates # of candidates
        # h indicates # of generated tokens
        state = hidden_states
        b = state.size(0)
        ind = input_ids.unsqueeze(0)
        all_probs = torch.empty(
            b, self.n_predict, self.vsize, device=state.device
        )  # b k h v
        assert (
            len(top_k_tokens_per_head) == self.n_predict
        ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"

        for i in range(self.n_predict):
            # Project and predict
            z = self.emb[i](ind)
            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
            state = self.proj[i](state) * self.state_weight + z
            state = self.activation(self.ln[i](state))  # b k d
            probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'

            # Update candidate set with new predictions

            # Update distribution set with new logits
            all_probs[:, i] = probs.exp()

            # Update state, log_probs and ind for new predictions
            state = state.unsqueeze(2).expand(
                -1, -1, top_k_tokens_per_head[i], -1
            )  # b k k' d
            state = state.reshape(-1, b, state.size(3))  # b kk' d
            ind = preds.view(-1, b)  # b kk'

        speculative_logits = all_probs
        return speculative_logits


class MLPSpeculatorHead(nn.Module):
    def __init__(self, lm_head, mlp_speculator):
        super().__init__()
        self.lm_head = lm_head
        self.mlp_speculator = mlp_speculator

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        # If we have too many tokens, we skip speculative logits
        if input.shape[0] > 128:
            return logits, None

        input_ids = logits.argmax(dim=-1)
        speculative_logits = self.mlp_speculator(input, input_ids)
        return logits, speculative_logits

    @staticmethod
    def load(config, prefix: str, weights):
        from pathlib import Path
        from safetensors import safe_open

        speculator_path = config.speculator["path"]

        for fname in config.speculator["model_paths"]:
            filename = str(Path(speculator_path) / fname)
            routing = weights.routing
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing and routing[k] != filename:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename

        mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MLPSpeculatorHead(lm_head, mlp_speculator)
...
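To make the recurrence in `MLPSpeculatorModel.forward` concrete, here is a self-contained toy sketch with random weights. The dimensions, the identity affine in the norm, and the greedy top-1 setting are assumptions for illustration only; the real module loads trained, possibly sharded, tensors:

```python
import math
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)
inner, vocab, n_predict, batch = 64, 100, 3, 2

def l2norm(x, weight, bias, eps=1e-6):
    # Same maths as MLPSpeculatorLayerNorm: RMS-scale, then elementwise affine.
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return weight * x + bias

# One embedding / projection / head per speculated position.
emb = nn.ModuleList([nn.Embedding(vocab, inner) for _ in range(n_predict)])
proj = nn.ModuleList([nn.Linear(inner, inner, bias=False) for _ in range(n_predict)])
head = nn.ModuleList([nn.Linear(inner, vocab, bias=False) for _ in range(n_predict)])
ln_w, ln_b = torch.ones(inner), torch.zeros(inner)  # identity affine for the sketch

# Blend weights: state_0 keeps ~50% of state magnitude by the final head.
state_weight = 0.5 ** (0.5 / n_predict)
emb_weight = math.sqrt(1 - state_weight**2)

state = torch.randn(1, batch, inner)    # last hidden state of the base model
ind = torch.randint(vocab, (1, batch))  # greedy token from the base LM head
all_probs = []
for i in range(n_predict):
    z = emb[i](ind) * (emb_weight * math.sqrt(inner / 2))
    state = proj[i](state) * state_weight + z      # mix running state with new token
    state = F.gelu(l2norm(state, ln_w, ln_b))
    probs = F.log_softmax(head[i](state), dim=-1)
    _, preds = probs.topk(1, dim=-1)               # top-1 candidate per head
    ind = preds.squeeze(-1)
    all_probs.append(probs.exp())

speculative_probs = torch.stack(all_probs, dim=1)
print(speculative_probs.shape)  # torch.Size([1, 3, 2, 100])
```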
 import torch
+import json
 
 from typing import Tuple, Optional
-from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
 from text_generation_server.layers.tensor_parallel import TensorParallelHead
+from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
+from text_generation_server.layers.mlp import MLPSpeculatorHead
 
 
 class SpeculativeHead(torch.nn.Module):
-    def __init__(self, lm_head, medusa):
+    def __init__(self, lm_head, speculator):
         super().__init__()
         self.head = lm_head
-        self.medusa = medusa
+        self.speculator = speculator
 
     @staticmethod
     def load(config, prefix: str, weights):
-        use_medusa = config.use_medusa
-        if use_medusa:
-            lm_head = None
+        speculator = config.speculator
+        if speculator:
+            speculator_path = config.speculator["path"]
+            speculator_config = str(speculator_path / "config.json")
+
+            with open(speculator_config, "r") as f:
+                speculator_config = json.load(f)
+
+            config.speculator_config = speculator_config
             try:
-                medusa = MedusaHeadV1.load(config, prefix, weights)
-            except:
-                medusa = MedusaHeadV2(config, prefix, weights)
+                architecture = speculator_config["architectures"][0]
+
+                if architecture == "MLPSpeculatorPreTrainedModel":
+                    speculator = MLPSpeculatorHead.load(config, prefix, weights)
+                else:
+                    speculator = None
+            except KeyError:
+                try:
+                    speculator = MedusaHeadV1.load(config, prefix, weights)
+                except:
+                    speculator = MedusaHeadV2(config, prefix, weights)
+            lm_head = None
         else:
             lm_head = TensorParallelHead.load(config, prefix, weights)
-            medusa = None
-        return SpeculativeHead(lm_head, medusa)
+            speculator = None
+        return SpeculativeHead(lm_head, speculator)
 
     def forward(
         self, input: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if self.medusa is not None:
-            return self.medusa(input)
+        if self.speculator is not None:
+            return self.speculator(input)
         assert self.head is not None
         logits = self.head(input)
...
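The dispatch above keys off the `architectures` field of the speculator repo's config.json: MLP speculators declare `MLPSpeculatorPreTrainedModel`, and the `except KeyError` branch implies that Medusa configs carry no `architectures` entry at all. A minimal sketch of that decision (paths and file contents are hypothetical):

```python
import json
from pathlib import Path


def speculator_kind(repo_dir: Path) -> str:
    """Mirror SpeculativeHead.load's dispatch on config.json (simplified)."""
    with open(repo_dir / "config.json") as f:
        cfg = json.load(f)
    try:
        if cfg["architectures"][0] == "MLPSpeculatorPreTrainedModel":
            return "mlp"
        return "unsupported"
    except KeyError:
        # No `architectures` entry: assume a Medusa head (V1, falling back to V2).
        return "medusa"
```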
 import torch
+import os
 
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional
 from pathlib import Path
@@ -135,8 +136,9 @@ def get_model(
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
+    model_type = config_dict.get("model_type", None)
 
-    use_medusa = None
+    speculator = None
     if "medusa_num_heads" in config_dict:
         medusa_model_id = model_id
         medusa_revision = revision
...
@@ -156,6 +158,8 @@ def get_model(
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+        # Reload model type from parent.
+        model_type = config_dict.get("model_type", None)
         is_local = Path(medusa_model_id).exists()
         if not is_local:
             medusa_config = hf_hub_download(
@@ -166,11 +170,70 @@ def get_model(
                 revision=medusa_revision,
                 filename="medusa_lm_head.safetensors",
             )
-            use_medusa = Path(medusa_config).parent
+            speculator = {
+                "path": Path(medusa_config).parent,
+                "model_paths": ["medusa_lm_head.safetensors"],
+            }
         else:
-            use_medusa = Path(medusa_model_id)
+            speculator = {
+                "path": Path(medusa_model_id),
+                "model_paths": ["medusa_lm_head.safetensors"],
+            }
 
         method = "medusa"
+    elif model_type == "mlp_speculator":
+        mlp_model_id = model_id
+        mlp_revision = revision
+        model_id = config_dict["base_model_name_or_path"]
+        revision = "main"
+        speculate_mlp = config_dict["n_predict"]
+        if speculate is not None:
+            if speculate > speculate_mlp:
+                raise RuntimeError(
+                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
+                )
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_mlp)
+
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        # Reload model type from parent.
+        model_type = config_dict.get("model_type", None)
+        is_local = Path(mlp_model_id).exists()
+        extension = ".safetensors"
+        if not is_local:
+            mlp_speculator_config = hf_hub_download(
+                mlp_model_id, revision=mlp_revision, filename="config.json"
+            )
+            api = HfApi()
+            info = api.model_info(mlp_model_id, revision=mlp_revision)
+            filenames = [
+                s.rfilename
+                for s in info.siblings
+                if s.rfilename.endswith(extension)
+                and len(s.rfilename.split("/")) == 1
+                and "arguments" not in s.rfilename
+                and "args" not in s.rfilename
+                and "training" not in s.rfilename
+            ]
+            for filename in filenames:
+                hf_hub_download(
+                    mlp_model_id,
+                    revision=mlp_revision,
+                    filename=filename,
+                )
+            speculator = {
+                "path": Path(mlp_speculator_config).parent,
+                "model_paths": filenames,
+            }
+        else:
+            speculator = Path(mlp_model_id)
+            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
+            speculator = {"path": speculator, "model_paths": filenames}
+
+        method = "mlp_speculator"
     else:
         method = "n-gram"
@@ -178,7 +241,6 @@ def get_model(
     if speculate > 0:
         logger.info(f"Using speculation {method} with {speculate} input ids.")
 
-    model_type = config_dict.get("model_type", None)
     if model_type is None:
         # TODO: fix how we determine model type for Mamba
         if "ssm_cfg" in config_dict:
...
@@ -202,7 +264,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
...
(The same one-line `use_medusa=use_medusa,` to `speculator=speculator,` replacement repeats in every remaining model-constructor call inside `get_model`.)
...
@@ -42,7 +42,7 @@ class BLOOMSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -71,7 +71,7 @@ class BLOOMSharded(CausalLM):
         )
         config.pad_token_id = 3
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
...
@@ -482,12 +482,12 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        if use_medusa:
-            raise RuntimeError("Medusa decoding is not enabled for AutoModel")
+        if speculator:
+            raise RuntimeError("Speculator decoding is not enabled for AutoModel")
 
         if torch.cuda.is_available():
             device = torch.device("cuda")
...
@@ -683,9 +683,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
-        config.vision_config.use_medusa = config.use_medusa
+        config.vision_config.speculator = config.speculator
         config.text_config.quantize = config.quantize
-        config.text_config.use_medusa = config.use_medusa
+        config.text_config.speculator = config.speculator
 
         vision_config = config.vision_config
         self.text_model = load_text_model(
...
@@ -135,7 +135,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         self.vocab_size = config.text_config.vocab_size
         self.config = config
         config.text_config.quantize = config.quantize
-        config.text_config.use_medusa = config.use_medusa
+        config.text_config.speculator = config.speculator
         self.language_model = load_text_model(
             prefix="language_model" if not prefix else f"{prefix}.language_model",
             config=config.text_config,
...
@@ -1101,6 +1101,8 @@ class FlashCausalLM(Model):
             next_token_texts = []
             left = 0
 
+            logger.info(f"Accepted ids {n_accepted_ids}")
+
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
                 # Generated token
...
@@ -24,7 +24,7 @@ class FlashCohere(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -49,7 +49,7 @@ class FlashCohere(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
...
@@ -26,7 +26,7 @@ class FlashDbrx(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -74,7 +74,7 @@ class FlashDbrx(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
...
@@ -25,7 +25,7 @@ class FlashGemma(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -50,7 +50,7 @@ class FlashGemma(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
...
@@ -27,7 +27,7 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -71,7 +71,7 @@ class FlashLlama(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
...
@@ -313,7 +313,7 @@ class BaseFlashMistral(FlashCausalLM):
         config_cls=AutoConfig,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
@@ -340,7 +340,7 @@ class BaseFlashMistral(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         # Set context windows
         if getattr(config, "sliding_window", None) is not None:
@@ -567,7 +567,7 @@ class FlashMistral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -577,7 +577,7 @@ class FlashMistral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
...
@@ -15,7 +15,7 @@ class FlashMixtral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -25,7 +25,7 @@ class FlashMixtral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
...
@@ -25,7 +25,7 @@ class FlashNeoXSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,7 +51,7 @@ class FlashNeoXSharded(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
...
@@ -25,7 +25,7 @@ class FlashPhi(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -48,7 +48,7 @@ class FlashPhi(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
@@ -58,7 +58,7 @@ class FlashPhi(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
 
         model = FlashPhiForCausalLM(config, weights)
-        if use_medusa:
+        if speculator:
             from text_generation_server.utils.medusa import MedusaModel
             from huggingface_hub import hf_hub_download
             import json
@@ -66,19 +66,19 @@ class FlashPhi(FlashCausalLM):
             from pathlib import Path
 
             is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
+                Path(speculator).exists() and Path(speculator).is_dir()
             ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
 
             if not is_local_model:
                 medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
+                    speculator, revision=revision, filename="config.json"
                 )
                 medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
+                    speculator, revision=revision, filename="medusa_lm_head.pt"
                 )
             else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+                medusa_config = str(Path(speculator) / "config.json")
+                medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
 
             with open(medusa_config, "r") as f:
                 config = json.load(f)
...
@@ -30,7 +30,7 @@ class FlashQwen2(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -53,7 +53,7 @@ class FlashQwen2(BaseFlashMistral):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         # Set context windows
         if config.sliding_window is not None:
...
@@ -26,7 +26,7 @@ class FlashRWSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -66,7 +66,7 @@ class FlashRWSharded(FlashCausalLM):
         )
 
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id, revision)
...