Unverified Commit 3a521c92 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

feat: mixtral (#1328)

parent 9ecfa16b
......@@ -154,6 +154,11 @@ COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm-cuda
# Build megablocks
FROM kernel-builder as megablocks-builder
RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base
......@@ -175,8 +180,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy conda with PyTorch installed
COPY --from=pytorch-install /opt/conda /opt/conda
# Copy conda with PyTorch and Megablocks installed
COPY --from=megablocks-builder /opt/conda /opt/conda
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
......
......@@ -629,6 +629,9 @@ pub async fn run(
// Batch size buckets
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
// Speculated tokens buckets
let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
// Prometheus handler
let builder = PrometheusBuilder::new()
......@@ -641,6 +644,8 @@ pub async fn run(
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
.unwrap()
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
.unwrap()
.set_buckets_for_metric(skipped_matcher, &skipped_buckets)
.unwrap();
let prom_handle = builder
.install_recorder()
......
......@@ -16,6 +16,9 @@ gen-server:
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py
install-megablocks:
pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
install: gen-server
pip install pip --upgrade
pip install -r requirements_cuda.txt
......
import os
import torch
from loguru import logger
......@@ -78,6 +77,18 @@ except ImportError as e:
if MISTRAL:
__all__.append(FlashMistral)
MIXTRAL = True
try:
from text_generation_server.models.flash_mixtral import FlashMixtral
except ImportError as e:
logger.warning(f"Could not import Mixtral model: {e}")
MIXTRAL = False
if MIXTRAL:
__all__.append(FlashMixtral)
def get_model(
model_id: str,
revision: Optional[str],
......@@ -141,7 +152,6 @@ def get_model(
use_medusa = None
if "medusa_num_heads" in config_dict:
use_medusa = model_id
medusa_config = config_dict
model_id = config_dict["base_model_name_or_path"]
revision = "main"
speculate_medusa = config_dict["medusa_num_heads"]
......@@ -292,7 +302,18 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Mistral model requires flash attention v2")
raise NotImplementedError("Mistral models requires flash attention v2")
if model_type == "mixtral":
if MIXTRAL:
return FlashMixtral(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")
if model_type == "opt":
return OPTSharded(
......
......@@ -34,14 +34,8 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastRMSNorm
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
class LlamaConfig(PretrainedConfig):
def __init__(
......@@ -95,75 +89,6 @@ class LlamaConfig(PretrainedConfig):
)
class LlamaRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
......@@ -363,10 +288,8 @@ class FlashLlamaLayer(nn.Module):
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = LlamaRMSNorm(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = LlamaRMSNorm(
self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
......@@ -430,7 +353,7 @@ class FlashLlamaModel(torch.nn.Module):
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
......
......@@ -35,13 +35,9 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastRMSNorm
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
raise ImportError("Mistral model requires flash attn v2")
......@@ -100,76 +96,6 @@ class MistralConfig(PretrainedConfig):
**kwargs,
)
class MistralRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
......@@ -371,10 +297,10 @@ class MistralLayer(nn.Module):
)
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = MistralRMSNorm(
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = MistralRMSNorm(
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
......@@ -440,7 +366,7 @@ class MistralModel(torch.nn.Module):
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = MistralRMSNorm(
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
......
......@@ -6,7 +6,6 @@ from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
......
......@@ -8,14 +8,13 @@ from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers.models.llama import LlamaTokenizerFast
from typing import Optional, Tuple, Type
from typing import Optional, Tuple, Type, List
from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
from text_generation_server.models.cache_manager import (
get_cache_manager,
set_cache_manager,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
......@@ -105,7 +104,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
tokenized_input = tokenized_input[-r.truncate :]
tokenized_input = tokenized_input[-r.truncate:]
input_length = len(tokenized_input)
input_lengths.append(input_length)
......@@ -278,9 +277,11 @@ class FlashMistralBatch(FlashCausalLMBatch):
)
class FlashMistral(FlashCausalLM):
class BaseFlashMistral(FlashCausalLM):
def __init__(
self,
config_cls,
model_cls,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
......@@ -305,7 +306,7 @@ class FlashMistral(FlashCausalLM):
trust_remote_code=trust_remote_code,
)
config = MistralConfig.from_pretrained(
config = config_cls.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
......@@ -321,10 +322,10 @@ class FlashMistral(FlashCausalLM):
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id)
model = FlashMistralForCausalLM(config, weights)
model = model_cls(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashMistral, self).__init__(
super(BaseFlashMistral, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
......@@ -396,3 +397,23 @@ class FlashMistral(FlashCausalLM):
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits
class FlashMistral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMistral, self).__init__(
config_cls=MistralConfig,
model_cls=FlashMistralForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code
)
import torch
from typing import Optional
from text_generation_server.models.flash_mistral import BaseFlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
class FlashMixtral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMixtral, self).__init__(
config_cls=MixtralConfig,
model_cls=FlashMixtralForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code
)
......@@ -47,12 +47,14 @@ elif CAN_EXLLAMA:
create_exllama_buffers,
set_device,
)
HAS_EXLLAMA = "2"
else:
from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear,
create_exllama_buffers,
set_device,
)
HAS_EXLLAMA = "1"
except ImportError:
......@@ -526,9 +528,12 @@ class TensorParallelEmbedding(nn.Module):
try:
if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
else:
dropout_layer_norm = None
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
......@@ -563,10 +568,81 @@ try:
residual = hidden_states
return normed_hidden_states, residual
class FastRMSNorm(nn.Module):
def __init__(self, weight: torch.Tensor, eps: float):
super().__init__()
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
@classmethod
def load(cls, prefix, weights, eps=1e-6):
weight = weights.get_tensor(f"{prefix}.weight")
return cls(weight, eps)
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
except ImportError:
pass
try:
if IS_CUDA_SYSTEM:
from flash_attn.layers.rotary import RotaryEmbedding
......@@ -574,12 +650,14 @@ try:
elif IS_ROCM_SYSTEM:
from vllm import pos_encoding_ops
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return inv_freq
def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {
......@@ -589,6 +667,7 @@ try:
return rope_scaling
return getattr(config, "rope_scaling", None)
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq, scaling_factor):
super().__init__()
......@@ -606,12 +685,12 @@ try:
if IS_CUDA_SYSTEM:
rotary_dim = cos.shape[-1]
q1 = query[..., :rotary_dim]
q2 = query[..., rotary_dim : 2 * rotary_dim]
q2 = query[..., rotary_dim: 2 * rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., :rotary_dim]
k2 = key[..., rotary_dim : 2 * rotary_dim]
k2 = key[..., rotary_dim: 2 * rotary_dim]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif IS_ROCM_SYSTEM:
......@@ -630,7 +709,8 @@ try:
True
)
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
@classmethod
def static(cls, config, dim, base, device):
......@@ -747,6 +827,7 @@ try:
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
return cos.unsqueeze(1), sin.unsqueeze(1)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device)
......@@ -783,8 +864,11 @@ try:
# Inverse dim formula to find dim based on number of rotations
import math
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
# Find dim range bounds based on rotations
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
......@@ -792,7 +876,8 @@ try:
low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(
high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim-1) # Clamp values just in case
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def linear_ramp_mask(min, max, dim):
if min == max:
......@@ -802,13 +887,16 @@ try:
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def get_mscale(scale=1):
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor,*, extrapolation_factor, attn_factor, beta_fast, beta_slow):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *, extrapolation_factor,
attn_factor, beta_fast, beta_slow):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
......@@ -818,7 +906,8 @@ try:
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
self.mscale = float(get_mscale(
self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
......@@ -834,13 +923,15 @@ try:
)
freqs = 1.0 / inv_freq_extrapolation
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings)
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base,
self.max_position_embeddings)
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(
device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
self.inv_freq = inv_freq
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
self.mscale = float(get_mscale(
self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment