Commit 72ee382d authored by OlivierDehaene's avatar OlivierDehaene
Browse files

chore: formatting

parent 3a521c92
......@@ -25,6 +25,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2
def serialize(
self,
data,
......@@ -69,7 +70,9 @@ class ResponseComparator(JSONSnapshotExtension):
prefill_token.id == other.id
and prefill_token.text == other.text
and (
math.isclose(prefill_token.logprob, other.logprob, rel_tol=self.rtol)
math.isclose(
prefill_token.logprob, other.logprob, rel_tol=self.rtol
)
if prefill_token.logprob is not None
else prefill_token.logprob == other.logprob
)
......@@ -153,6 +156,7 @@ class GenerousResponseComparator(ResponseComparator):
# Needed for GPTQ with exllama which has serious numerical fluctuations.
rtol = 0.75
class LauncherHandle:
def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}")
......@@ -198,6 +202,7 @@ class ProcessLauncherHandle(LauncherHandle):
def response_snapshot(snapshot):
return snapshot.use_extension(ResponseComparator)
@pytest.fixture
def generous_response_snapshot(snapshot):
return snapshot.use_extension(GenerousResponseComparator)
......@@ -219,7 +224,7 @@ def launcher(event_loop):
quantize: Optional[str] = None,
trust_remote_code: bool = False,
use_flash_attention: bool = True,
dtype: Optional[str] = None
dtype: Optional[str] = None,
):
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
......@@ -282,7 +287,7 @@ def launcher(event_loop):
quantize: Optional[str] = None,
trust_remote_code: bool = False,
use_flash_attention: bool = True,
dtype: Optional[str] = None
dtype: Optional[str] = None,
):
port = random.randint(8000, 10_000)
......@@ -335,7 +340,7 @@ def launcher(event_loop):
],
volumes=volumes,
ports={"80/tcp": port},
shm_size="1G"
shm_size="1G",
)
yield ContainerLauncherHandle(client, container.name, port)
......
......@@ -50,10 +50,16 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
responses = await generate_load(flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4)
responses = await generate_load(
flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == '\nDeep learning is a subset of machine learning'
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert (
responses[0].generated_text == "\nDeep learning is a subset of machine learning"
)
assert responses == response_snapshot
......@@ -56,7 +56,9 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapsho
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == ": Let n = 10 - 1"
assert responses == response_snapshot
......@@ -3,7 +3,9 @@ import pytest
@pytest.fixture(scope="module")
def idefics_handle(launcher):
with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16") as handle:
with launcher(
"HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16"
) as handle:
yield handle
......
......@@ -133,8 +133,20 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([token_id.item() == 10264 for generation in generations for token_id in generation.tokens.token_ids])
assert all([token_text == "Test" for generation in generations for token_text in generation.tokens.texts])
assert all(
[
token_id.item() == 10264
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == "Test"
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0
......
......@@ -129,8 +129,20 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([token_id.item() == 13 for generation in generations for token_id in generation.tokens.token_ids])
assert all([token_text == "." for generation in generations for token_text in generation.tokens.texts])
assert all(
[
token_id.item() == 13
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == "."
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0
......
......@@ -151,8 +151,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([token_id.item() == 259 for generation in generations for token_id in generation.tokens.token_ids])
assert all([token_text == " " for generation in generations for token_text in generation.tokens.texts])
assert all(
[
token_id.item() == 259
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == " "
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0
......
......@@ -77,12 +77,24 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value
if dtype is not None and quantize not in {None, "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4"}:
if dtype is not None and quantize not in {
None,
"bitsandbytes",
"bitsandbytes-nf4",
"bitsandbytes-fp4",
}:
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
server.serve(
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
model_id,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
uds_path,
)
......@@ -140,12 +152,17 @@ def download_weights(
try:
import json
medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
medusa_head = hf_hub_download(
model_id, revision=revision, filename="medusa_lm_head.pt"
)
if auto_convert:
medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
if not medusa_sf.exists():
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
medusa_config = hf_hub_download(
model_id, revision=revision, filename="config.json"
)
with open(medusa_config, "r") as f:
config = json.load(f)
......@@ -153,10 +170,17 @@ def download_weights(
revision = "main"
try:
utils.weight_files(model_id, revision, extension)
logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
logger.info(
f"Files for parent {model_id} are already present on the host. "
"Skipping download."
)
return
# Local files not found
except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
except (
utils.LocalEntryNotFoundError,
FileNotFoundError,
utils.EntryNotFoundError,
):
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
......
......@@ -88,7 +88,6 @@ if MIXTRAL:
__all__.append(FlashMixtral)
def get_model(
model_id: str,
revision: Optional[str],
......@@ -157,7 +156,9 @@ def get_model(
speculate_medusa = config_dict["medusa_num_heads"]
if speculate is not None:
if speculate > speculate_medusa:
raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match")
raise RuntimeError(
"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
)
else:
set_speculate(speculate)
else:
......@@ -249,7 +250,7 @@ def get_model(
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
use_medusa=use_medusa
use_medusa=use_medusa,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
......@@ -313,7 +314,9 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")
raise NotImplementedError(
"Mixtral models requires flash attention v2, stk and megablocks"
)
if model_type == "opt":
return OPTSharded(
......@@ -354,7 +357,7 @@ def get_model(
raise ValueError("awq quantization is not supported for AutoModel")
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
raise ValueError("4bit quantization is not supported for AutoModel")
elif (quantize == "eetq"):
elif quantize == "eetq":
raise ValueError("Eetq quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
......
......@@ -74,7 +74,11 @@ class BLOOMSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer",
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
prefix="transformer",
)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
......
......@@ -510,7 +510,11 @@ class CausalLM(Model):
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes":
if (
torch.cuda.is_available()
and torch.cuda.device_count() == 1
and quantize != "bitsandbytes"
):
model = model.cuda()
if tokenizer.pad_token_id is None:
......@@ -676,7 +680,10 @@ class CausalLM(Model):
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
prefill_token_ids,
prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
......@@ -703,11 +710,11 @@ class CausalLM(Model):
request.id,
prefill_tokens,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)
......
......@@ -34,9 +34,10 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastRMSNorm
FastRMSNorm,
)
class LlamaConfig(PretrainedConfig):
def __init__(
self,
......@@ -202,7 +203,7 @@ class FlashLlamaAttention(torch.nn.Module):
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
paged_attention.reshape_and_cache(
......@@ -237,7 +238,7 @@ class FlashLlamaAttention(torch.nn.Module):
input_lengths,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
......@@ -288,7 +289,9 @@ class FlashLlamaLayer(nn.Module):
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
......
......@@ -27,7 +27,11 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA
from text_generation_server.utils.flash_attn import (
attention,
HAS_FLASH_ATTN_V2_ROCM,
HAS_FLASH_ATTN_V2_CUDA,
)
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
......@@ -35,7 +39,7 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastRMSNorm
FastRMSNorm,
)
......@@ -96,6 +100,7 @@ class MistralConfig(PretrainedConfig):
**kwargs,
)
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
......
......@@ -29,7 +29,10 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA
from text_generation_server.utils.flash_attn import (
HAS_FLASH_ATTN_V2_ROCM,
HAS_FLASH_ATTN_V2_CUDA,
)
from text_generation_server.utils.layers import (
FastLinear,
FastRMSNorm,
......@@ -59,28 +62,28 @@ class MixtralConfig(PretrainedConfig):
model_type = "mixtral"
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=4096,
num_experts_per_tok=2,
num_local_experts=8,
**kwargs,
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=4096,
num_experts_per_tok=2,
num_local_experts=8,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
......@@ -166,16 +169,18 @@ def _load_experts(config, prefix, mat, weights):
rank = weights.process_group.rank()
assert (
config.intermediate_size % world_size == 0
config.intermediate_size % world_size == 0
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
block_size = config.intermediate_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size),
dtype=weights.dtype,
device=weights.device)
tensor = torch.empty(
(config.num_local_experts * block_size, config.hidden_size),
dtype=weights.dtype,
device=weights.device,
)
for i in range(config.num_local_experts):
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
......@@ -184,16 +189,18 @@ def _load_experts(config, prefix, mat, weights):
expert_slice = slice_[:, start:stop].t().contiguous()
else:
expert_slice = slice_[start:stop]
tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device)
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
dtype=weights.dtype
).to(device=weights.device)
return tensor
class MixtralAttention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
self,
prefix: str,
config,
weights,
):
super().__init__()
self.max_past = (
......@@ -210,7 +217,7 @@ class MixtralAttention(torch.nn.Module):
device=weights.device,
)
self.softmax_scale = self.head_size ** -0.5
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
......@@ -219,7 +226,7 @@ class MixtralAttention(torch.nn.Module):
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
......@@ -236,17 +243,17 @@ class MixtralAttention(torch.nn.Module):
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
......@@ -399,8 +406,9 @@ class BlockSparseMoE(nn.Module):
# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(padded_bins, self.blocking, block_rows,
blocks_per_row)
column_indices = ops.topology(
padded_bins, self.blocking, block_rows, blocks_per_row
)
# For now, use meta init to save the device memory.
data = torch.empty(
......@@ -444,8 +452,7 @@ class BlockSparseMoE(nn.Module):
# position of each bin.
# List of size num_experts
padded_tokens_per_expert = round_up(tokens_per_expert,
self.blocking)
padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
# padded_tokens_per_expert => [128, O, 128, ...]
# Cumulative selected experts per token
......@@ -484,8 +491,7 @@ class BlockSparseMoE(nn.Module):
# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
self.top_k)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
# Create the sparse matrix topology
with torch.no_grad():
......@@ -496,8 +502,8 @@ class BlockSparseMoE(nn.Module):
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
self.act(stk.ops.sdd(x, self.w1, topo).data) *
stk.ops.sdd(x, self.w3, topo).data,
self.act(stk.ops.sdd(x, self.w1, topo).data)
* stk.ops.sdd(x, self.w3, topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
......@@ -537,7 +543,9 @@ class MixtralLayer(nn.Module):
self.self_attn = MixtralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
self.block_sparse_moe = BlockSparseMoE(
f"{prefix}.block_sparse_moe", config, weights
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
......@@ -549,18 +557,18 @@ class MixtralLayer(nn.Module):
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
......@@ -615,16 +623,16 @@ class MixtralModel(torch.nn.Module):
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
......@@ -670,17 +678,17 @@ class FlashMixtralForCausalLM(torch.nn.Module):
raise ValueError("max_past cannot be None")
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
......
......@@ -198,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
image = image_url_or_urls
if image.startswith("http://") or image.startswith("https://"):
response = requests.get(image_url_or_urls, stream=True, headers=headers, timeout=(1, 5))
response = requests.get(
image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)
)
response.raise_for_status()
content = response.content
elif image.startswith("data:"):
......@@ -213,7 +215,7 @@ class IdeficsImageProcessor(BaseImageProcessor):
image = Image.open(BytesIO(content))
# image.verify()
except Exception:
raise ValueError(f"Could not load image from url {image_url_or_urls}")
raise ValueError(f"Could not load image from url {image_url_or_urls}")
return image
else:
raise ValueError(
......
......@@ -62,6 +62,7 @@ if IS_CUDA_SYSTEM:
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
image_hidden_states: Optional[torch.FloatTensor] = None
......@@ -431,7 +432,9 @@ class IdeficsRMSNorm(nn.Module):
return out
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
# this was adapted from LlamaMLP
......@@ -613,8 +616,13 @@ class IdeficsAttention(nn.Module):
query_shape = query_states.shape
key_shape = key_states.shape
self.rotary_emb(query_states.view(-1, *query_shape[2:]), key_states.reshape(-1, *key_shape[2:]), cos, sin)
self.rotary_emb(
query_states.view(-1, *query_shape[2:]),
key_states.reshape(-1, *key_shape[2:]),
cos,
sin,
)
query_states = query_states.view(query_shape)
key_states = key_states.view(key_shape)
......
......@@ -112,6 +112,7 @@ def is_url(string):
result = urlparse(string)
return all([result.scheme, result.netloc])
def is_image(string):
"""Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
invalidated the url"""
......@@ -344,7 +345,6 @@ class IdeficsProcessor(ProcessorMixin):
image_objects = self.image_processor(image_objects, transform=transform)
text_encoding = self.tokenizer(
text=full_text,
add_special_tokens=False,
......
......@@ -11,7 +11,7 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
from text_generation_server.models import Model
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
Batch,
......@@ -165,8 +165,6 @@ class FlashCausalLMBatch(Batch):
input_length = len(tokenized_input)
input_lengths.append(input_length)
prefix_offsets.append(input_length - 5)
read_offsets.append(input_length)
......@@ -229,7 +227,9 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens + speculative_length)
max_length = max(
max_length, input_length + max_new_tokens + speculative_length
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
......@@ -424,7 +424,9 @@ class FlashCausalLMBatch(Batch):
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
speculative_ids = (
self.speculative_ids[indices] if self.speculative_ids is not None else None
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
......@@ -480,7 +482,9 @@ class FlashCausalLMBatch(Batch):
total_batch_size += len(b)
total_slots += len(b.slots)
blocks += b.blocks
speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
speculative_length = (
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
)
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max(
......@@ -586,7 +590,11 @@ class FlashCausalLMBatch(Batch):
device=batches[0].next_token_chooser.device,
)
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None
speculative_ids = (
torch.cat([b.speculative_ids for b in batches], dim=0)
if batches[0].speculative_ids is not None
else None
)
# Needed to avoid dropping blocks when the batches will go out of scope
for b in batches:
......@@ -622,7 +630,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids
speculative_ids=speculative_ids,
)
def __del__(self):
......@@ -727,43 +735,54 @@ class FlashCausalLM(Model):
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
if batch.speculative_ids is not None:
input_ids=batch.input_ids
position_ids=batch.position_ids
cu_seqlen_prefill=batch.cu_seqlen_prefill
kv_cache=get_cache_manager().kv_cache
block_tables=batch.block_tables_tensor
slots=batch.slots[batch.slot_indices]
input_lengths=batch.input_lengths_tensor
max_s=batch.max_seqlen
lm_head_indices=batch.prefill_head_indices
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
# Add Copy the block tables for all members
block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous()
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids=batch.input_ids
position_ids=batch.position_ids
cu_seqlen_prefill=batch.cu_seqlen_prefill
kv_cache=get_cache_manager().kv_cache
block_tables=batch.block_tables_tensor
slots=batch.slots[batch.slot_indices]
input_lengths=batch.input_lengths_tensor
max_s=batch.max_seqlen
lm_head_indices=batch.prefill_head_indices
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
return self.model.forward(
input_ids=input_ids,
......@@ -808,20 +827,31 @@ class FlashCausalLM(Model):
else:
speculative_logits = None
if prefill:
next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out
)
if speculative_logits is not None:
speculative_logits = (
speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits
speculative_logits[batch.prefill_next_token_indices]
if prefill_logprobs
else speculative_logits
)
else:
next_token_logits = out
next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
(
next_input_ids,
next_token_logprobs,
logprobs,
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen],
next_token_logits,
get_speculate(),
batch.speculative_ids,
speculative_logits,
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
......@@ -851,11 +881,7 @@ class FlashCausalLM(Model):
stopped = True
# Zipped iterator
iterator = zip(
batch.input_lengths,
batch.all_input_ids,
accepted_ids
)
iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second
# one, we need to first do a GPU <-> CPU sync
......@@ -863,11 +889,7 @@ class FlashCausalLM(Model):
# For each member of the batch
index = 0
for i, (
input_length,
all_input_ids,
n_accepted_ids
) in enumerate(iterator):
for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
......@@ -901,7 +923,6 @@ class FlashCausalLM(Model):
cumulative_length += input_length
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids
......@@ -983,8 +1004,10 @@ class FlashCausalLM(Model):
current_stopped = False
stopped = stopped and current_stopped
_next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[
index : index + n_accepted_ids - left
]
index += n_accepted_ids
# Shard generations
......@@ -1027,7 +1050,10 @@ class FlashCausalLM(Model):
)
prefill_tokens = Tokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = []
prefill_token_ids,
request_prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
......
......@@ -71,12 +71,19 @@ class FlashLlama(FlashCausalLM):
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
)
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
medusa_head = hf_hub_download(
use_medusa, revision=revision, filename="medusa_lm_head.pt"
)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
......
......@@ -45,11 +45,11 @@ class FlashMistralBatch(FlashCausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
......@@ -99,12 +99,12 @@ class FlashMistralBatch(FlashCausalLMBatch):
# Parse batch
for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
zip(pb.requests, batch_tokenized_inputs)
):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
tokenized_input = tokenized_input[-r.truncate:]
tokenized_input = tokenized_input[-r.truncate :]
input_length = len(tokenized_input)
input_lengths.append(input_length)
......@@ -184,7 +184,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens + speculative_length)
max_length = max(
max_length, input_length + max_new_tokens + speculative_length
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
......@@ -273,20 +275,20 @@ class FlashMistralBatch(FlashCausalLMBatch):
blocks=blocks,
max_blocks=max_blocks,
prefill_cache_indices=prefill_cache_indices,
speculative_ids=None
speculative_ids=None,
)
class BaseFlashMistral(FlashCausalLM):
def __init__(
self,
config_cls,
model_cls,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
self,
config_cls,
model_cls,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
......@@ -345,43 +347,54 @@ class BaseFlashMistral(FlashCausalLM):
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
if batch.speculative_ids is not None:
input_ids=batch.input_ids
position_ids=batch.position_ids
cu_seqlen_prefill=batch.cu_seqlen_prefill
kv_cache=get_cache_manager().kv_cache
block_tables=batch.block_tables_tensor
slots=batch.slots[batch.slot_indices]
input_lengths=batch.input_lengths_tensor
max_s=batch.max_seqlen
lm_head_indices=batch.prefill_head_indices
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
# Add Copy the block tables for all members
block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous()
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids=batch.input_ids
position_ids=batch.position_ids
cu_seqlen_prefill=batch.cu_seqlen_prefill
kv_cache=get_cache_manager().kv_cache
block_tables=batch.block_tables_tensor
slots=batch.slots[batch.slot_indices]
input_lengths=batch.input_lengths_tensor
max_s=batch.max_seqlen
lm_head_indices=batch.prefill_head_indices
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
......@@ -401,12 +414,12 @@ class BaseFlashMistral(FlashCausalLM):
class FlashMistral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMistral, self).__init__(
config_cls=MistralConfig,
......@@ -415,5 +428,5 @@ class FlashMistral(BaseFlashMistral):
revision=revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code
trust_remote_code=trust_remote_code,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment