Unverified Commit 9946165e authored by OlivierDehaene, committed by GitHub

chore: add pre-commit (#1569)

parent 142cdabe
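The pre-commit configuration file added by this commit is not visible in this excerpt; the hunks below show the mechanical reformatting (black-style line wrapping and whitespace normalization) that the new hooks applied across the Python server code. As a rough, hypothetical illustration only, not the actual file from this commit, a minimal .pre-commit-config.yaml that would produce this kind of formatting could look like:

repos:
  - repo: https://github.com/psf/black
    rev: 24.2.0  # hypothetical pin; the actual commit may use a different version and additional hooks
    hooks:
      - id: black

With such a config, contributors would typically run `pre-commit install` once to register the git hook and `pre-commit run --all-files` for a one-off pass over the repository, which is the kind of sweep reflected in the diff below.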
@@ -69,7 +69,12 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
         qzeros = qzeros.to(device=weights.device)

-        bits, groupsize, _, quant_method, = weights._get_gptq_params()
+        (
+            bits,
+            groupsize,
+            _,
+            quant_method,
+        ) = weights._get_gptq_params()
         if quant_method == "gptq":
             g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
             g_idx = g_idx.to(device=weights.device)
@@ -306,9 +311,9 @@ class MLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
...
@@ -66,6 +66,7 @@ class IdeficsVisionConfig(PretrainedConfig):
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
     """
+
     model_type = "idefics"
     attribute_map = {
         "hidden_size": "embed_dim",
@@ -125,6 +126,7 @@ class IdeficsPerceiverConfig(PretrainedConfig):
         qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):
             Whether or not to use qk layer norms in perceiver
     """
+
     model_type = "idefics"

     def __init__(
@@ -219,6 +221,7 @@ class IdeficsConfig(PretrainedConfig):
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
+
     model_type = "idefics"
     is_composition = True
...
@@ -123,10 +123,10 @@ def expand_inputs_for_generation(
             raise ValueError(
                 "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
             )
-        encoder_outputs[
-            "last_hidden_state"
-        ] = encoder_outputs.last_hidden_state.index_select(
-            0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
-        )
+        encoder_outputs["last_hidden_state"] = (
+            encoder_outputs.last_hidden_state.index_select(
+                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
+            )
+        )
         model_kwargs["encoder_outputs"] = encoder_outputs
     return input_ids, model_kwargs
...
@@ -133,6 +133,7 @@ class IdeficsProcessor(ProcessorMixin):
             An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input.
         image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)
     """
+
     attributes = ["image_processor", "tokenizer"]
     image_processor_class = "IdeficsImageProcessor"
     tokenizer_class = "LlamaTokenizerFast"
...
@@ -19,10 +19,12 @@ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 import math
 from dataclasses import dataclass

+
 @dataclass
 class InferenceParams:
     """Inference parameters that are passed to the main model in order
     to efficiently calculate and store the context during inference."""
+
     max_seqlen: int
     max_batch_size: int
     conv_states: torch.Tensor
@@ -137,13 +139,28 @@ class MambaBlock(nn.Module):
     def step(self, hidden_states, conv_state, ssm_state):
         xz = self.in_proj(hidden_states.squeeze(1))
         x, z = xz.chunk(2, dim=-1)  # (B D)
-        x = causal_conv1d_update(x, conv_state, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation)
+        x = causal_conv1d_update(
+            x,
+            conv_state,
+            self.conv1d.weight.squeeze(1),
+            self.conv1d.bias,
+            self.activation,
+        )
         x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
         dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
         dt = F.linear(dt, self.dt_proj.weight)
         A = self.negA
         y = selective_state_update(
-            ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+            ssm_state,
+            x,
+            dt,
+            A,
+            B,
+            C,
+            self.D,
+            z=z,
+            dt_bias=self.dt_proj.bias,
+            dt_softplus=True,
         )
         out = self.out_proj(y)
         return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()
...
@@ -2,6 +2,7 @@
 Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """
+
 import math
 import os
 import warnings
...
@@ -35,7 +35,6 @@ from text_generation_server.utils.dist import MEMORY_FRACTION

 tracer = trace.get_tracer(__name__)

-
 @dataclass
 class FlashCausalLMBatch(Batch):
     batch_id: int
@@ -1213,8 +1212,9 @@ class FlashCausalLM(Model):
             # accept each new token for this specific request since we may
             # have more than one new token per request with speculative decoding
             for next_token_id in _next_token_ids:
-                batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id)
+                batch.next_token_chooser = (
+                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
+                )

             # Update values
             batch.input_lengths[i] = input_length + n_accepted_ids
...
@@ -92,7 +92,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
...
@@ -114,7 +114,9 @@ class IdeficsCausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -401,9 +403,9 @@ class IdeficsCausalLMBatch(Batch):
                 pixel_values = batch.pixel_values.new_zeros(
                     (total_batch_size, max_num_images, 3, 224, 224)
                 )
-            pixel_values[
-                start_index:end_index, :curr_batch_max_num_images
-            ] = batch.pixel_values
+            pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
+                batch.pixel_values
+            )

             if image_attention_mask is None:
                 image_attention_mask = batch.image_attention_mask.new_zeros(
@@ -500,14 +502,14 @@ class IdeficsCausalLMBatch(Batch):
                 # We slice the keys to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
-                    padded_past_keys[
-                        start_index:end_index, :, -past_seq_len:, :
-                    ] = past_keys[:, :, -past_seq_len:, :]
+                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_keys[:, :, -past_seq_len:, :]
+                    )
                 else:
                     # BLOOM case
-                    padded_past_keys[
-                        start_index:end_index, :, :, -past_seq_len:
-                    ] = past_keys[:, :, :, -past_seq_len:]
+                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+                        past_keys[:, :, :, -past_seq_len:]
+                    )
                 del past_keys

                 start_index = end_index
@@ -525,9 +527,9 @@ class IdeficsCausalLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past values to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_values[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                    past_values[:, :, -past_seq_len:, :]
+                )
                 del past_values

                 # Update values
@@ -603,9 +605,11 @@ class IdeficsCausalLM(Model):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                device_map="auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
-                else None,
+                device_map=(
+                    "auto"
+                    if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                    else None
+                ),
                 load_in_8bit=quantize == "bitsandbytes",
                 trust_remote_code=trust_remote_code,
             )
@@ -836,9 +840,9 @@ class IdeficsCausalLM(Model):
             # Update attention_mask as we added a new token to input_ids
             batch.attention_mask[:, -batch.padding_right_offset] = 1
-            batch.image_attention_mask[
-                :, -batch.padding_right_offset, :
-            ] = batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
+            batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
+                batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
+            )

             # Decrease right offset
             batch.padding_right_offset -= 1
...
@@ -15,7 +15,10 @@ from text_generation_server.utils import (
 )
 from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
 import time
-from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams
+from text_generation_server.models.custom_modeling.mamba_modeling import (
+    MambaModel,
+    InferenceParams,
+)
 from text_generation_server.models import Model
 from typing import Any, List, Optional, Tuple, Type, Dict
 from text_generation_server.models.types import (
@@ -28,21 +31,35 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling
 from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

-def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: int, d_state: int, seqlen_offset: int, dtype: torch.dtype, device: torch.device):
+
+def new_inference_params(
+    n_blocks: int,
+    batch_size: int,
+    d_inner: int,
+    d_conv: int,
+    d_state: int,
+    seqlen_offset: int,
+    dtype: torch.dtype,
+    device: torch.device,
+):
     max_seqlen = 0
     conv_states = torch.zeros(
-        (n_blocks,
-        batch_size,
-        d_inner,
-        d_conv,),
+        (
+            n_blocks,
+            batch_size,
+            d_inner,
+            d_conv,
+        ),
         device=device,
         dtype=dtype,
     )
     ssm_states = torch.zeros(
-        (n_blocks,
-        batch_size,
-        d_inner,
-        d_state,),
+        (
+            n_blocks,
+            batch_size,
+            d_inner,
+            d_state,
+        ),
         device=device,
         dtype=dtype,
     )
@@ -52,7 +69,6 @@ def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: i
         seqlen_offset=seqlen_offset,
         conv_states=conv_states,
         ssm_states=ssm_states,
-
     )
     return inference_params
@@ -124,7 +140,9 @@ class MambaBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -251,7 +269,9 @@ class MambaBatch(Batch):
         # TODO
         # Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
-        self.inference_params.conv_states = self.inference_params.conv_states[:, indices]
+        self.inference_params.conv_states = self.inference_params.conv_states[
+            :, indices
+        ]
         self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
         return self
@@ -280,13 +300,20 @@ class MambaBatch(Batch):
         max_seqlen = 0
         seqlen_offset = 0

-        (n_blocks, _, d_inner, d_conv) = (
-            batches[0].inference_params.conv_states.shape
-        )
+        (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
         (_, _, _, d_state) = batches[0].inference_params.ssm_states.shape
         dtype = batches[0].inference_params.conv_states.dtype
         device = batches[0].inference_params.conv_states.device
-        inference_params = new_inference_params(n_blocks=n_blocks, batch_size=total_batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=device, dtype=dtype)
+        inference_params = new_inference_params(
+            n_blocks=n_blocks,
+            batch_size=total_batch_size,
+            d_state=d_state,
+            d_conv=d_conv,
+            d_inner=d_inner,
+            seqlen_offset=seqlen_offset,
+            device=device,
+            dtype=dtype,
+        )

         # Batch tensors
         input_ids = None
@@ -334,13 +361,20 @@ class MambaBatch(Batch):
                 max_input_length - batch.max_input_length
             ) * len(batch)

-            inference_params.max_seqlen = max(inference_params.max_seqlen, batch.inference_params.max_seqlen)
+            inference_params.max_seqlen = max(
+                inference_params.max_seqlen, batch.inference_params.max_seqlen
+            )
             assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset"
-            inference_params.seqlen_offset = max(inference_params.seqlen_offset, batch.inference_params.seqlen_offset)
+            inference_params.seqlen_offset = max(
+                inference_params.seqlen_offset, batch.inference_params.seqlen_offset
+            )

-            inference_params.conv_states[:, start_index:end_index] = batch.inference_params.conv_states
-            inference_params.ssm_states[:, start_index:end_index] = batch.inference_params.ssm_states
+            inference_params.conv_states[:, start_index:end_index] = (
+                batch.inference_params.conv_states
+            )
+            inference_params.ssm_states[:, start_index:end_index] = (
+                batch.inference_params.ssm_states
+            )

             start_index = end_index
@@ -452,36 +486,39 @@ class Mamba(Model):
         # Important seqlen_offset to go through the update mechanism with the state
         seqlen_offset = 1
-        inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype)
+        inference_params = new_inference_params(
+            n_blocks=n_blocks,
+            batch_size=batch_size,
+            d_state=d_state,
+            d_conv=d_conv,
+            d_inner=d_inner,
+            seqlen_offset=seqlen_offset,
+            device=self.device,
+            dtype=self.dtype,
+        )

         graph = torch.cuda.CUDAGraph()
         torch.cuda.synchronize()
         # Run once outside to warmup
-        self.model.forward(
-            input_ids=input_ids,
-            inference_params=inference_params
-        )
+        self.model.forward(input_ids=input_ids, inference_params=inference_params)
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
             logits = self.model.forward(
-                input_ids=input_ids,
-                inference_params=inference_params
+                input_ids=input_ids, inference_params=inference_params
             )
         torch.cuda.synchronize()
         graph_dict = {
             "input_ids": input_ids,
             "inference_params": inference_params,
             "graph": graph,
-            "logits": logits
+            "logits": logits,
         }
         self.cuda_graphs[batch_size] = graph_dict

     def forward(
-        self,
-        input_ids: torch.Tensor,
-        inference_params: Any
+        self, input_ids: torch.Tensor, inference_params: Any
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         bs = input_ids.shape[0]
         padded_bs = bs
@@ -504,15 +541,21 @@ class Mamba(Model):
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
-        cuda_graph["input_ids"][: bs] = input_ids
-        cuda_graph["inference_params"].conv_states[:, : bs] = inference_params.conv_states
-        cuda_graph["inference_params"].ssm_states[:, : bs] = inference_params.ssm_states
+        cuda_graph["input_ids"][:bs] = input_ids
+        cuda_graph["inference_params"].conv_states[
+            :, :bs
+        ] = inference_params.conv_states
+        cuda_graph["inference_params"].ssm_states[:, :bs] = inference_params.ssm_states

         # Replay the graph
         cuda_graph["graph"].replay()

-        inference_params.conv_states.copy_(cuda_graph["inference_params"].conv_states[:, :bs])
-        inference_params.ssm_states.copy_(cuda_graph["inference_params"].ssm_states[:, :bs])
+        inference_params.conv_states.copy_(
+            cuda_graph["inference_params"].conv_states[:, :bs]
+        )
+        inference_params.ssm_states.copy_(
+            cuda_graph["inference_params"].ssm_states[:, :bs]
+        )

         # Slice output to the correct shape
         return cuda_graph["logits"][:bs]
@@ -528,19 +571,25 @@ class Mamba(Model):
         if batch.inference_params is None:
             # 0 is important here
             seqlen_offset = 0
             n_blocks = len(self.model.blocks)
             d_state = self.model.config.d_state
             d_conv = self.model.config.d_conv
             d_inner = self.model.config.d_inner
-            inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype)
+            inference_params = new_inference_params(
+                n_blocks=n_blocks,
+                batch_size=batch_size,
+                d_state=d_state,
+                d_conv=d_conv,
+                d_inner=d_inner,
+                seqlen_offset=seqlen_offset,
+                device=self.device,
+                dtype=self.dtype,
+            )
             batch.inference_params = inference_params

         # Forward pass
-        logits = self.forward(
-            input_ids, inference_params=batch.inference_params
-        )
+        logits = self.forward(input_ids, inference_params=batch.inference_params)

         # batch.inference_params = new_inference_params
         # Results
@@ -694,9 +743,9 @@ class Mamba(Model):
             generations.append(generation)

             # Update values
-            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
-                next_token_id_squeezed.item()
-            )
+            batch.next_token_choosers[i] = batch.next_token_choosers[
+                i
+            ].advance_grammar(next_token_id_squeezed.item())
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
...
@@ -36,9 +36,11 @@ class RW(CausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
...
@@ -96,7 +96,9 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -351,9 +353,9 @@ class Seq2SeqLMBatch(Batch):
                     (total_batch_size, max_input_length),
                 )
             # Copy to correct indices
-            attention_mask[
-                start_index:end_index, -batch.max_input_length :
-            ] = batch.attention_mask[:, -batch.max_input_length :]
+            attention_mask[start_index:end_index, -batch.max_input_length :] = (
+                batch.attention_mask[:, -batch.max_input_length :]
+            )

             # Create padded tensor
             if decoder_input_ids is None:
@@ -547,9 +549,11 @@ class Seq2SeqLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
@@ -750,7 +754,7 @@ class Seq2SeqLM(Model):
                 if top_n_tokens > 0:
                     all_top_tokens = []
-                    for (top_token_ids, top_token_logprobs) in zip(
+                    for top_token_ids, top_token_logprobs in zip(
                         top_token_ids, top_token_logprobs
                     ):
                         toptoken_texts = self.tokenizer.batch_decode(
...
@@ -88,14 +88,16 @@ class Generation:
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
             request_id=self.request_id,
-            prefill_tokens=self.prefill_tokens.to_pb()
-            if self.prefill_tokens is not None
-            else None,
+            prefill_tokens=(
+                self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None
+            ),
             tokens=self.tokens.to_pb(),
-            generated_text=self.generated_text.to_pb()
-            if self.generated_text is not None
-            else None,
+            generated_text=(
+                self.generated_text.to_pb() if self.generated_text is not None else None
+            ),
-            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens]
-            if self.top_tokens is not None
-            else None,
+            top_tokens=(
+                [top_tokens.to_pb() for top_tokens in self.top_tokens]
+                if self.top_tokens is not None
+                else None
+            ),
         )
*.py
*.pyi
*.py-e
\ No newline at end of file
@@ -182,7 +182,7 @@ try:
             )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
             zeros = (zeros >> zeros_shifter[None, :]) & maxq
             zeros = (zeros + 1) & maxq  # eventually avoid overflow
-            a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+            a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
             b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
...
@@ -355,7 +355,9 @@ def get_linear(weight, bias, quantize):
                 "to use Exllama/GPTQ kernels for AWQ inference."
             )
         if not HAS_AWQ:
-            raise NotImplementedError("You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly")
+            raise NotImplementedError(
+                "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
+            )
         linear = WQLinear(
             w_bit=bits,
             group_size=groupsize,
...
@@ -516,7 +516,7 @@ class GrammarLogitProcessor(LogitsProcessor):
         if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
             schema = build_regex_from_object(schema)
         elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
-            pass # schema is already a regex just here for clarity
+            pass  # schema is already a regex just here for clarity
         fsm = RegexFSM(schema, tokenizer)
         logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
...
@@ -409,8 +409,12 @@ class HeterogeneousNextTokenChooser:
     def advance_grammar_single(self, grammar_state_index: int, next_id: int):
         if self.grammar_processor is not None:
-            self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index(
-                next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index
-            )
+            self.fsm_grammar_states[grammar_state_index] = (
+                self.grammar_processor.advance_at_index(
+                    next_id,
+                    self.fsm_grammar_states[grammar_state_index],
+                    grammar_state_index,
+                )
+            )
         return self
...