Unverified Commit 9946165e authored by OlivierDehaene, committed by GitHub

chore: add pre-commit (#1569)

parent 142cdabe
......@@ -69,7 +69,12 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device)
bits, groupsize, _, quant_method, = weights._get_gptq_params()
(
bits,
groupsize,
_,
quant_method,
) = weights._get_gptq_params()
if quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
......@@ -306,9 +311,9 @@ class MLP(nn.Module):
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
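For reference, the reformatted conditional above only changes layout: the fast GELU variants still select the tanh approximation. A minimal standalone sketch of the same selection logic (editor's illustration; the act names come from the hunk, the input tensor is arbitrary):

# Editor's illustrative sketch, not part of the diff.
import torch

def gelu_for(act: str, x: torch.Tensor) -> torch.Tensor:
    # "gelu_fast" and "gelu_pytorch_tanh" use the tanh approximation;
    # other gelu variants fall back to the exact erf-based form.
    approximate = "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
    return torch.nn.functional.gelu(x, approximate=approximate)

x = torch.randn(4)
print(gelu_for("gelu_fast", x), gelu_for("gelu", x))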
......
......@@ -66,6 +66,7 @@ class IdeficsVisionConfig(PretrainedConfig):
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
"""
model_type = "idefics"
attribute_map = {
"hidden_size": "embed_dim",
......@@ -125,6 +126,7 @@ class IdeficsPerceiverConfig(PretrainedConfig):
qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):
Whether or not to use qk layer norms in perceiver
"""
model_type = "idefics"
def __init__(
......@@ -219,6 +221,7 @@ class IdeficsConfig(PretrainedConfig):
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "idefics"
is_composition = True
......
......@@ -123,10 +123,10 @@ def expand_inputs_for_generation(
raise ValueError(
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
)
encoder_outputs[
"last_hidden_state"
] = encoder_outputs.last_hidden_state.index_select(
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
encoder_outputs["last_hidden_state"] = (
encoder_outputs.last_hidden_state.index_select(
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
)
)
model_kwargs["encoder_outputs"] = encoder_outputs
return input_ids, model_kwargs
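For reference, the reformatted assignment above expands encoder_outputs.last_hidden_state along the batch dimension so it matches the expanded input_ids. A small standalone sketch of the same index_select expansion (editor's illustration; shapes and expand_size are assumptions):

# Editor's illustrative sketch, not part of the diff.
import torch

batch_size, seq_len, hidden, expand_size = 2, 5, 8, 3
last_hidden_state = torch.randn(batch_size, seq_len, hidden)

# Index [0, 0, 0, 1, 1, 1]: each batch row is repeated expand_size times.
expanded_return_idx = (
    torch.arange(batch_size).view(-1, 1).repeat(1, expand_size).view(-1)
)
expanded = last_hidden_state.index_select(0, expanded_return_idx)
print(expanded.shape)  # torch.Size([6, 5, 8])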
......
......@@ -133,6 +133,7 @@ class IdeficsProcessor(ProcessorMixin):
An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input.
image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "IdeficsImageProcessor"
tokenizer_class = "LlamaTokenizerFast"
......
......@@ -19,10 +19,12 @@ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
import math
from dataclasses import dataclass
@dataclass
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference."""
max_seqlen: int
max_batch_size: int
conv_states: torch.Tensor
......@@ -137,13 +139,28 @@ class MambaBlock(nn.Module):
def step(self, hidden_states, conv_state, ssm_state):
xz = self.in_proj(hidden_states.squeeze(1))
x, z = xz.chunk(2, dim=-1) # (B D)
x = causal_conv1d_update(x, conv_state, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation)
x = causal_conv1d_update(
x,
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = F.linear(dt, self.dt_proj.weight)
A = self.negA
y = selective_state_update(
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
ssm_state,
x,
dt,
A,
B,
C,
self.D,
z=z,
dt_bias=self.dt_proj.bias,
dt_softplus=True,
)
out = self.out_proj(y)
return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()
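For readers unfamiliar with the fused kernel used above, here is a rough pure-PyTorch approximation of what causal_conv1d_update does in the single-token step, assuming the reference semantics of the causal-conv1d package and a SiLU activation (editor's illustration only; the model calls the fused CUDA kernel, not this function):

# Editor's illustrative sketch, not part of the diff.
import torch
import torch.nn.functional as F

def causal_conv1d_update_ref(x, conv_state, weight, bias=None):
    # x: (batch, dim), conv_state: (batch, dim, width), weight: (dim, width)
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # slide the cached window
    conv_state[:, :, -1] = x                                      # append the new token
    out = (conv_state * weight).sum(dim=-1)                       # depthwise causal conv
    if bias is not None:
        out = out + bias
    return F.silu(out)                                            # activation (assumed SiLU)

batch, dim, width = 2, 16, 4
x = torch.randn(batch, dim)
conv_state = torch.zeros(batch, dim, width)
weight = torch.randn(dim, width)
print(causal_conv1d_update_ref(x, conv_state, weight).shape)  # torch.Size([2, 16])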
......
......@@ -2,6 +2,7 @@
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""
import math
import os
import warnings
......
......@@ -35,7 +35,6 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
tracer = trace.get_tracer(__name__)
@dataclass
class FlashCausalLMBatch(Batch):
batch_id: int
......@@ -1213,8 +1212,9 @@ class FlashCausalLM(Model):
# accept each new token for this specific request since we may
# have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids:
batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id)
batch.next_token_chooser = (
batch.next_token_chooser.advance_grammar_single(i, next_token_id)
)
# Update values
batch.input_lengths[i] = input_length + n_accepted_ids
......
......@@ -92,7 +92,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
......
......@@ -114,7 +114,9 @@ class IdeficsCausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
......@@ -401,9 +403,9 @@ class IdeficsCausalLMBatch(Batch):
pixel_values = batch.pixel_values.new_zeros(
(total_batch_size, max_num_images, 3, 224, 224)
)
pixel_values[
start_index:end_index, :curr_batch_max_num_images
] = batch.pixel_values
pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
batch.pixel_values
)
if image_attention_mask is None:
image_attention_mask = batch.image_attention_mask.new_zeros(
......@@ -500,14 +502,14 @@ class IdeficsCausalLMBatch(Batch):
# We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
if batch.keys_head_dim_last:
padded_past_keys[
start_index:end_index, :, -past_seq_len:, :
] = past_keys[:, :, -past_seq_len:, :]
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
past_keys[:, :, -past_seq_len:, :]
)
else:
# BLOOM case
padded_past_keys[
start_index:end_index, :, :, -past_seq_len:
] = past_keys[:, :, :, -past_seq_len:]
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
past_keys[:, :, :, -past_seq_len:]
)
del past_keys
start_index = end_index
......@@ -525,9 +527,9 @@ class IdeficsCausalLMBatch(Batch):
end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
padded_past_values[
start_index:end_index, :, -past_seq_len:, :
] = past_values[:, :, -past_seq_len:, :]
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
past_values[:, :, -past_seq_len:, :]
)
del past_values
# Update values
......@@ -603,9 +605,11 @@ class IdeficsCausalLM(Model):
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
......@@ -836,9 +840,9 @@ class IdeficsCausalLM(Model):
# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
batch.image_attention_mask[
:, -batch.padding_right_offset, :
] = batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
)
# Decrease right offset
batch.padding_right_offset -= 1
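For reference, the reformatted assignments above mark the newly generated token inside the right padding and then shrink the padding window. A tiny standalone sketch of the attention-mask part (editor's illustration; the mask values are made up):

# Editor's illustrative sketch, not part of the diff.
import torch

padding_right_offset = 3
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0]])

attention_mask[:, -padding_right_offset] = 1  # new token occupies the first padding slot
padding_right_offset -= 1                     # one fewer padding column remains

print(attention_mask, padding_right_offset)  # tensor([[1, 1, 1, 1, 0, 0]]) 2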
......
......@@ -15,7 +15,10 @@ from text_generation_server.utils import (
)
from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
import time
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaModel,
InferenceParams,
)
from text_generation_server.models import Model
from typing import Any, List, Optional, Tuple, Type, Dict
from text_generation_server.models.types import (
......@@ -28,21 +31,35 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: int, d_state: int, seqlen_offset: int, dtype: torch.dtype, device: torch.device):
def new_inference_params(
n_blocks: int,
batch_size: int,
d_inner: int,
d_conv: int,
d_state: int,
seqlen_offset: int,
dtype: torch.dtype,
device: torch.device,
):
max_seqlen = 0
conv_states = torch.zeros(
(n_blocks,
batch_size,
d_inner,
d_conv,),
(
n_blocks,
batch_size,
d_inner,
d_conv,
),
device=device,
dtype=dtype,
)
ssm_states = torch.zeros(
(n_blocks,
batch_size,
d_inner,
d_state,),
(
n_blocks,
batch_size,
d_inner,
d_state,
),
device=device,
dtype=dtype,
)
......@@ -52,7 +69,6 @@ def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: i
seqlen_offset=seqlen_offset,
conv_states=conv_states,
ssm_states=ssm_states,
)
return inference_params
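For reference, new_inference_params allocates two zeroed state buffers per Mamba block, shaped as shown above. A short standalone sketch with assumed (hypothetical) dimensions (editor's illustration):

# Editor's illustrative sketch, not part of the diff.
import torch

n_blocks, batch_size = 24, 4            # assumptions: number of Mamba blocks, batch size
d_inner, d_conv, d_state = 1536, 4, 16  # assumptions: typical Mamba dimensions

conv_states = torch.zeros(n_blocks, batch_size, d_inner, d_conv)  # cached conv window per block
ssm_states = torch.zeros(n_blocks, batch_size, d_inner, d_state)  # SSM hidden state per block
print(conv_states.shape, ssm_states.shape)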
......@@ -124,7 +140,9 @@ class MambaBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
......@@ -251,7 +269,9 @@ class MambaBatch(Batch):
# TODO
# Kept it simple by just updating the state; updating the other CPU values may also be necessary.
self.inference_params.conv_states = self.inference_params.conv_states[:, indices]
self.inference_params.conv_states = self.inference_params.conv_states[
:, indices
]
self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
return self
......@@ -280,13 +300,20 @@ class MambaBatch(Batch):
max_seqlen = 0
seqlen_offset = 0
(n_blocks, _, d_inner, d_conv) = (
batches[0].inference_params.conv_states.shape
)
(n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
(_, _, _, d_state) = batches[0].inference_params.ssm_states.shape
dtype = batches[0].inference_params.conv_states.dtype
device = batches[0].inference_params.conv_states.device
inference_params = new_inference_params(n_blocks=n_blocks, batch_size=total_batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=device, dtype=dtype)
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=total_batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=device,
dtype=dtype,
)
# Batch tensors
input_ids = None
......@@ -334,13 +361,20 @@ class MambaBatch(Batch):
max_input_length - batch.max_input_length
) * len(batch)
inference_params.max_seqlen = max(inference_params.max_seqlen, batch.inference_params.max_seqlen)
inference_params.max_seqlen = max(
inference_params.max_seqlen, batch.inference_params.max_seqlen
)
assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset"
inference_params.seqlen_offset = max(inference_params.seqlen_offset, batch.inference_params.seqlen_offset)
inference_params.seqlen_offset = max(
inference_params.seqlen_offset, batch.inference_params.seqlen_offset
)
inference_params.conv_states[:, start_index:end_index] = batch.inference_params.conv_states
inference_params.ssm_states[:, start_index:end_index] = batch.inference_params.ssm_states
inference_params.conv_states[:, start_index:end_index] = (
batch.inference_params.conv_states
)
inference_params.ssm_states[:, start_index:end_index] = (
batch.inference_params.ssm_states
)
start_index = end_index
......@@ -452,36 +486,39 @@ class Mamba(Model):
# Important: seqlen_offset must be non-zero to go through the update mechanism with the state
seqlen_offset = 1
inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype)
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
graph = torch.cuda.CUDAGraph()
torch.cuda.synchronize()
# Run once outside to warmup
self.model.forward(
input_ids=input_ids,
inference_params=inference_params
)
self.model.forward(input_ids=input_ids, inference_params=inference_params)
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
logits = self.model.forward(
input_ids=input_ids,
inference_params=inference_params
input_ids=input_ids, inference_params=inference_params
)
torch.cuda.synchronize()
graph_dict = {
"input_ids": input_ids,
"inference_params": inference_params,
"graph": graph,
"logits": logits
"logits": logits,
}
self.cuda_graphs[batch_size] = graph_dict
def forward(
self,
input_ids: torch.Tensor,
inference_params: Any
self, input_ids: torch.Tensor, inference_params: Any
) -> Tuple[torch.Tensor, torch.Tensor]:
bs = input_ids.shape[0]
padded_bs = bs
......@@ -504,15 +541,21 @@ class Mamba(Model):
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: bs] = input_ids
cuda_graph["inference_params"].conv_states[:, : bs] = inference_params.conv_states
cuda_graph["inference_params"].ssm_states[:, : bs] = inference_params.ssm_states
cuda_graph["input_ids"][:bs] = input_ids
cuda_graph["inference_params"].conv_states[
:, :bs
] = inference_params.conv_states
cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states
# Replay the graph
cuda_graph["graph"].replay()
inference_params.conv_states.copy_(cuda_graph["inference_params"].conv_states[:, :bs])
inference_params.ssm_states.copy_(cuda_graph["inference_params"].ssm_states[:, :bs])
inference_params.conv_states.copy_(
cuda_graph["inference_params"].conv_states[:, :bs]
)
inference_params.ssm_states.copy_(
cuda_graph["inference_params"].ssm_states[:, :bs]
)
# Slice output to the correct shape
return cuda_graph["logits"][:bs]
......@@ -528,19 +571,25 @@ class Mamba(Model):
if batch.inference_params is None:
# 0 is important here
seqlen_offset = 0
n_blocks = len(self.model.blocks)
d_state = self.model.config.d_state
d_conv = self.model.config.d_conv
d_inner = self.model.config.d_inner
inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype)
inference_params = new_inference_params(
n_blocks=n_blocks,
batch_size=batch_size,
d_state=d_state,
d_conv=d_conv,
d_inner=d_inner,
seqlen_offset=seqlen_offset,
device=self.device,
dtype=self.dtype,
)
batch.inference_params = inference_params
# Forward pass
logits = self.forward(
input_ids, inference_params=batch.inference_params
)
logits = self.forward(input_ids, inference_params=batch.inference_params)
# batch.inference_params = new_inference_params
# Results
......@@ -694,9 +743,9 @@ class Mamba(Model):
generations.append(generation)
# Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.next_token_choosers[i] = batch.next_token_choosers[
i
].advance_grammar(next_token_id_squeezed.item())
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
......
......@@ -36,9 +36,11 @@ class RW(CausalLM):
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
......
......@@ -96,7 +96,9 @@ class Seq2SeqLMBatch(Batch):
inputs.append(r.inputs)
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
......@@ -351,9 +353,9 @@ class Seq2SeqLMBatch(Batch):
(total_batch_size, max_input_length),
)
# Copy to correct indices
attention_mask[
start_index:end_index, -batch.max_input_length :
] = batch.attention_mask[:, -batch.max_input_length :]
attention_mask[start_index:end_index, -batch.max_input_length :] = (
batch.attention_mask[:, -batch.max_input_length :]
)
# Create padded tensor
if decoder_input_ids is None:
......@@ -547,9 +549,11 @@ class Seq2SeqLM(Model):
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
......@@ -750,7 +754,7 @@ class Seq2SeqLM(Model):
if top_n_tokens > 0:
all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip(
for top_token_ids, top_token_logprobs in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode(
......
......@@ -88,14 +88,16 @@ class Generation:
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
request_id=self.request_id,
prefill_tokens=self.prefill_tokens.to_pb()
if self.prefill_tokens is not None
else None,
prefill_tokens=(
self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None
),
tokens=self.tokens.to_pb(),
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens]
if self.top_tokens is not None
else None,
generated_text=(
self.generated_text.to_pb() if self.generated_text is not None else None
),
top_tokens=(
[top_tokens.to_pb() for top_tokens in self.top_tokens]
if self.top_tokens is not None
else None
),
)
*.py
*.pyi
*.py-e
\ No newline at end of file
*.py-e
......@@ -182,7 +182,7 @@ try:
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1) & maxq # eventually avoid overflow
zeros = (zeros + 1) & maxq # eventually avoid overflow
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
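For reference, the two lines above unpack per-column zero-points from packed int32 words; the +1 compensates for the off-by-one storage convention of the GPTQ format, and the mask guards against overflow (per the comment). A rough standalone sketch of the same bit arithmetic, assuming 4-bit packing (bits=4, maxq=15; editor's illustration):

# Editor's illustrative sketch, not part of the diff.
import torch

bits = 4
maxq = (1 << bits) - 1                                   # 15
packed = torch.tensor([0x76543210], dtype=torch.int32)   # eight 4-bit values in one int32 word
zeros_shifter = torch.arange(32 // bits, dtype=torch.int32) * bits  # [0, 4, 8, ..., 28]

zeros = (packed >> zeros_shifter) & maxq  # unpack -> tensor([0, 1, 2, 3, 4, 5, 6, 7])
zeros = (zeros + 1) & maxq                # add the stored offset back; the mask keeps values in range
print(zeros)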
......
......@@ -355,7 +355,9 @@ def get_linear(weight, bias, quantize):
"to use Exllama/GPTQ kernels for AWQ inference."
)
if not HAS_AWQ:
raise NotImplementedError("You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `--quantize gptq`: an AWQ->GPTQ conversion will happen on the fly")
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `--quantize gptq`: an AWQ->GPTQ conversion will happen on the fly"
)
linear = WQLinear(
w_bit=bits,
group_size=groupsize,
......
......@@ -516,7 +516,7 @@ class GrammarLogitProcessor(LogitsProcessor):
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
schema = build_regex_from_object(schema)
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
pass # schema is already a regex just here for clarity
pass # schema is already a regex just here for clarity
fsm = RegexFSM(schema, tokenizer)
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
return fsm
......
......@@ -409,8 +409,12 @@ class HeterogeneousNextTokenChooser:
def advance_grammar_single(self, grammar_state_index: int, next_id: int):
if self.grammar_processor is not None:
self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index(
next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index
self.fsm_grammar_states[grammar_state_index] = (
self.grammar_processor.advance_at_index(
next_id,
self.fsm_grammar_states[grammar_state_index],
grammar_state_index,
)
)
return self
......