Unverified Commit e52be9bb authored by Daniël de Kok, committed by GitHub

Add support for Deepseek V2 (#2224)

Deepseek V2 is a MoE model from Deepseek. The relevant differences
compared to other models are:

- Grouped top-K in expert selection (sketched below).
- mscale in yarn is calculated using the `mscale` and `mscale_all_dim`
  configuration options.
- `mscale_all_dim` is also used in scaling attention softmax.
- The query/key representations are permuted before the rotary
  embeddings are applied.
- Some projections cannot be sharded (`q_a_proj`, `kv_a_proj_with_mqa`),
  so we need a way to load whole (unsharded) weights that also supports
  quantized weights. To this end, `{Weights,WeightsLoader}.get_weights`
  was added.
- The query/key head dimensionality differs from that of the value,
  so we need to pad during attention.
- Heads of size 192 need an extension to our paged attention fork, and
  we need to ensure that the KV cache is allocated with the correct
  size.
- Shared experts.
parent 68a9685f
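The grouped top-K routing mentioned in the first bullet partitions the experts into groups, keeps only the best-scoring groups for each token, and then selects the final top-K experts from the surviving groups. Below is a minimal sketch of that selection; the function name, parameters, and shapes are illustrative, not the exact code added by this commit.

# Sketch of grouped top-k expert selection (illustrative, not the TGI code).
import torch


def grouped_topk(
    scores: torch.Tensor,  # [num_tokens, num_experts] router scores
    n_groups: int,  # experts are split into this many groups
    topk_groups: int,  # groups kept per token
    topk_experts: int,  # experts finally selected per token
):
    num_tokens, num_experts = scores.shape
    # Score each group by its best expert.
    group_scores = scores.view(num_tokens, n_groups, -1).amax(dim=-1)
    # Keep only the best-scoring groups for each token.
    group_idx = torch.topk(group_scores, k=topk_groups, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    # Broadcast the group mask back to per-expert granularity.
    expert_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, n_groups, num_experts // n_groups)
        .reshape(num_tokens, num_experts)
    )
    # Experts outside the selected groups can never win the final top-k.
    masked_scores = scores.masked_fill(expert_mask == 0, float("-inf"))
    return torch.topk(masked_scores, k=topk_experts, dim=-1)


# Example: 8 experts in 4 groups, keep 2 groups, pick 2 experts per token.
weights, indices = grouped_topk(
    torch.rand(3, 8), n_groups=4, topk_groups=2, topk_experts=2
)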
@@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models on specific hardware
 ## Supported Models
+- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
 - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
 - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
 - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
......
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<|begin▁of▁sentence|>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.1875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.84375,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.34375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.8359375,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.0859375,
"special": false,
"text": " is"
},
{
"id": 254,
"logprob": -1.5390625,
"special": false,
"text": " the"
},
{
"id": 1022,
"logprob": -1.1875,
"special": false,
"text": " first"
},
{
"id": 3458,
"logprob": -0.35546875,
"special": false,
"text": " step"
},
{
"id": 279,
"logprob": -0.8828125,
"special": false,
"text": " in"
},
{
"id": 254,
"logprob": -0.71484375,
"special": false,
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is the first step in the"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<|begin▁of▁sentence|>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.1875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 2143,
"logprob": -1.828125,
"special": false,
"text": " sent"
},
{
"id": 10081,
"logprob": -0.36914062,
"special": false,
"text": " successfully"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 185,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 1380,
"logprob": -0.38671875,
"special": false,
"text": "We"
},
{
"id": 543,
"logprob": -0.12695312,
"special": false,
"text": " will"
},
{
"id": 752,
"logprob": -0.20117188,
"special": false,
"text": " get"
},
{
"id": 279,
"logprob": 0.0,
"special": false,
"text": " in"
},
{
"id": 5402,
"logprob": 0.0,
"special": false,
"text": " touch"
},
{
"id": 366,
"logprob": 0.0,
"special": false,
"text": " with"
}
],
"top_tokens": null
},
"generated_text": "Test request sent successfully.\nWe will get in touch with"
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<|begin▁of▁sentence|>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.1875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.8125,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.890625,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5390625,
"special": false,
"text": " a"
},
{
"id": 3102,
"logprob": -2.609375,
"special": false,
"text": " request"
},
{
"id": 327,
"logprob": -0.75,
"special": false,
"text": " for"
},
{
"id": 245,
"logprob": -1.1171875,
"special": false,
"text": " a"
},
{
"id": 1727,
"logprob": -0.90625,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a request for a test"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<|begin▁of▁sentence|>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.25,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.8125,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.890625,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5390625,
"special": false,
"text": " a"
},
{
"id": 3102,
"logprob": -2.609375,
"special": false,
"text": " request"
},
{
"id": 327,
"logprob": -0.75,
"special": false,
"text": " for"
},
{
"id": 245,
"logprob": -1.1171875,
"special": false,
"text": " a"
},
{
"id": 1727,
"logprob": -0.90625,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a request for a test"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<|begin▁of▁sentence|>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.25,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.8125,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.890625,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5390625,
"special": false,
"text": " a"
},
{
"id": 3102,
"logprob": -2.609375,
"special": false,
"text": " request"
},
{
"id": 327,
"logprob": -0.75,
"special": false,
"text": " for"
},
{
"id": 245,
"logprob": -1.1171875,
"special": false,
"text": " a"
},
{
"id": 1727,
"logprob": -0.90625,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a request for a test"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<|begin▁of▁sentence|>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.25,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.8125,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.890625,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5390625,
"special": false,
"text": " a"
},
{
"id": 3102,
"logprob": -2.609375,
"special": false,
"text": " request"
},
{
"id": 327,
"logprob": -0.75,
"special": false,
"text": " for"
},
{
"id": 245,
"logprob": -1.1171875,
"special": false,
"text": " a"
},
{
"id": 1727,
"logprob": -0.90625,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a request for a test"
}
]
import pytest


@pytest.fixture(scope="module")
def flash_deepseek_v2_handle(launcher):
    with launcher("deepseek-ai/DeepSeek-V2-Lite", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_deepseek_v2(flash_deepseek_v2_handle):
    await flash_deepseek_v2_handle.health(300)
    return flash_deepseek_v2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_deepseek_v2(flash_deepseek_v2, response_snapshot):
    response = await flash_deepseek_v2.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_deepseek_v2_all_params(flash_deepseek_v2, response_snapshot):
    response = await flash_deepseek_v2.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_deepseek_v2_load(
    flash_deepseek_v2, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_deepseek_v2, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
-commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
+commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
 commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
 build-vllm-cuda:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
 		git clone https://github.com/Narsil/vllm.git vllm; \
 	fi
-	cd vllm && git fetch && git checkout $(commit_cuda) && python setup.py build
+	cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build
 install-vllm-cuda: build-vllm-cuda
-	cd vllm && git fetch && git checkout $(commit_cuda) && pip install -e .
+	cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e .
 build-vllm-rocm:
 	if [ ! -d 'vllm' ]; then \
......
@@ -34,15 +34,10 @@ class Exl2Weight(Weight):
 class Exl2WeightsLoader(WeightsLoader):
     """Loader for exl2-quantized weights."""

-    def get_weights_col_packed(
-        self,
-        weights: Weights,
-        prefix: str,
-        block_sizes: Union[int, List[int]],
-    ):
-        raise RuntimeError("Column-packed weights are not supported for exl")
-
-    def get_weights_col(self, weights: Weights, prefix: str):
+    def get_weights(self, weights: "Weights", prefix: str):
+        """
+        Get weights at the given prefix and apply without tensor parallelism.
+        """
         try:
             q_weight = weights.get_tensor(f"{prefix}.q_weight")
         except RuntimeError:
@@ -63,26 +58,21 @@ class Exl2WeightsLoader(WeightsLoader):
             q_groups=q_groups,
         )

+    def get_weights_col_packed(
+        self,
+        weights: Weights,
+        prefix: str,
+        block_sizes: Union[int, List[int]],
+    ):
+        raise RuntimeError("Column-packed weights are not supported for exl")
+
+    def get_weights_col(self, weights: Weights, prefix: str):
+        # Sharding is not yet supported, so we return the weights as-is.
+        return self.get_weights(weights, prefix)
+
     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
         raise ValueError("get_multi_weights_col is not supported for exl2")

     def get_weights_row(self, weights: Weights, prefix: str):
-        try:
-            q_weight = weights.get_tensor(f"{prefix}.q_weight")
-        except RuntimeError:
-            raise RuntimeError(
-                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
-            )
-
-        q_scale = weights.get_tensor(f"{prefix}.q_scale")
-        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
-        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
-        q_groups = weights.get_tensor(f"{prefix}.q_groups")
-
-        return Exl2Weight(
-            q_weight=q_weight,
-            q_scale=q_scale,
-            q_invperm=q_invperm,
-            q_scale_max=q_scale_max,
-            q_groups=q_groups,
-        )
+        # Sharding is not yet supported, so we return the weights as-is.
+        return self.get_weights(weights, prefix)
@@ -134,6 +134,115 @@ class GPTQWeightsLoader(WeightsLoader):
         self.quantize = quantize
         self.sym = sym

+    def get_weights(self, weights: Weights, prefix: str):
+        from text_generation_server.layers.marlin import (
+            can_use_gptq_marlin,
+            repack_gptq_for_marlin,
+        )
+
+        self._get_gptq_params(weights)
+        if can_use_gptq_marlin(
+            bits=self.bits,
+            groupsize=self.groupsize,
+            quant_method=self.quant_method,
+            quantize=self.quantize,
+            sym=self.sym,
+        ):
+            log_once(logger.info, "Using GPTQ-Marlin kernels")
+            try:
+                qweight = weights.get_tensor(f"{prefix}.qweight")
+            except RuntimeError:
+                raise RuntimeError(
+                    f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
+                )
+
+            g_idx = weights.get_tensor(f"{prefix}.g_idx")
+            scales = weights.get_tensor(f"{prefix}.scales")
+
+            return repack_gptq_for_marlin(
+                qweight=qweight,
+                scales=scales,
+                g_idx=g_idx,
+                bits=self.bits,
+                desc_act=self.desc_act,
+                groupsize=self.groupsize,
+                sym=self.sym,
+                sharded_infeatures=False,
+            )
+
+        use_exllama = True
+        if self.bits != 4:
+            use_exllama = False
+
+        if self.desc_act:
+            log_once(logger.warning, "Disabling exllama because desc_act=True")
+            use_exllama = False
+
+        try:
+            qweight = weights.get_tensor(f"{prefix}.qweight")
+        except RuntimeError:
+            raise RuntimeError(
+                "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+            )
+
+        if self.quantize == "gptq" and self.quant_method == "gptq":
+            g_idx = weights.get_tensor(f"{prefix}.g_idx")
+        else:
+            g_idx = None
+
+        from text_generation_server.layers.gptq import (
+            HAS_EXLLAMA,
+            CAN_EXLLAMA,
+            GPTQWeight,
+        )
+
+        if use_exllama:
+            if not HAS_EXLLAMA:
+                if CAN_EXLLAMA:
+                    log_once(
+                        logger.warning,
+                        "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
+                    )
+                use_exllama = False
+            else:
+                log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
+
+        qzeros = weights.get_tensor(f"{prefix}.qzeros")
+        scales = weights.get_tensor(f"{prefix}.scales")
+
+        if use_exllama and g_idx is not None:
+            g_idx = g_idx - g_idx[0]
+
+        if self.quantize == "gptq" and self.quant_method == "awq":
+            log_once(
+                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+            )
+            from text_generation_server.layers.awq.conversion_utils import (
+                fast_awq_to_gptq,
+            )
+
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            if use_exllama:
+                g_idx = None
+            else:
+                g_idx = (
+                    torch.arange(
+                        qweight.shape[0] * (32 // self.bits),
+                        device=qweight.device,
+                    )
+                    // self.groupsize
+                ).to(dtype=torch.int32)
+
+        return GPTQWeight(
+            qweight=qweight,
+            qzeros=qzeros,
+            scales=scales,
+            g_idx=g_idx,
+            bits=self.bits,
+            groupsize=self.groupsize,
+            use_exllama=use_exllama,
+        )
+
     def get_weights_col_packed(
         self,
         weights: Weights,
......
@@ -33,6 +33,35 @@ class MarlinWeightsLoader(WeightsLoader):
         self.bits = bits
         self.is_marlin_24 = is_marlin_24

+    def get_weights(self, weights: "Weights", prefix: str):
+        """
+        Get weights at the given prefix and apply without tensor parallelism.
+        """
+        is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
+        if is_marlin_24:
+            try:
+                B = weights.get_tensor(f"{prefix}.B_24")
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
+                )
+
+            B_meta = weights.get_tensor(f"{prefix}.B_meta")
+            s = weights.get_tensor(f"{prefix}.s")
+            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
+        else:
+            try:
+                B = weights.get_tensor(f"{prefix}.B")
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `marlin` weight, make sure the model is already quantized."
+                )
+
+            s = weights.get_tensor(f"{prefix}.s")
+            weight = MarlinWeight(B=B, s=s)
+
+        return weight
+
     def get_weights_col_packed(
         self,
         weights: Weights,
......
 import os
 import torch
 from torch import nn
+from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
@@ -97,6 +98,8 @@ class PositionRotaryEmbedding(nn.Module):
                 )
             elif rope_scaling["type"] == "yarn":
                 scaling_factor = rope_scaling["factor"]
+                mscale = rope_scaling.get("mscale", 1.0)
+                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                 return YarnPositionRotaryEmbedding(
                     dim=2 * inv_freq.shape[0],
                     max_position_embeddings=rope_scaling[
@@ -109,6 +112,8 @@ class PositionRotaryEmbedding(nn.Module):
                     attn_factor=1,
                     beta_fast=32,
                     beta_slow=1,
+                    mscale=mscale,
+                    mscale_all_dim=mscale_all_dim,
                 )
             elif rope_scaling["type"] in ["su", "longrope"]:
                 short_factor = torch.tensor(
@@ -181,6 +186,8 @@ class PositionRotaryEmbedding(nn.Module):
                     scaling_factor=scaling_factor,
                 )
             elif rope_scaling["type"] == "yarn":
+                mscale = rope_scaling.get("mscale", 1.0)
+                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                 return YarnPositionRotaryEmbedding(
                     dim=2 * inv_freq.shape[0],
                     max_position_embeddings=rope_scaling[
@@ -193,6 +200,8 @@ class PositionRotaryEmbedding(nn.Module):
                     attn_factor=1,
                     beta_fast=32,
                     beta_slow=1,
+                    mscale=mscale,
+                    mscale_all_dim=mscale_all_dim,
                 )
             else:
                 raise NotImplementedError(
@@ -346,10 +355,10 @@ def linear_ramp_mask(min, max, dim):
     return ramp_func


-def get_mscale(scale=1):
+def get_mscale(scale: float = 1.0, mscale: float = 1.0):
     if scale <= 1:
         return 1.0
-    return 0.1 * math.log(scale) + 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0


 class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
@@ -365,6 +374,8 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
         attn_factor,
         beta_fast,
         beta_slow,
+        mscale: float,
+        mscale_all_dim: float,
     ):
         inv_freq = _create_inv_freq(dim, base, device)
         super().__init__(inv_freq, scaling_factor)
@@ -375,8 +386,12 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
         self.attn_factor = attn_factor
         self.beta_fast = beta_fast
         self.beta_slow = beta_slow
+        self.mscale_all_dim = mscale_all_dim
+        self.scaling_factor = scaling_factor
         self.mscale = float(
-            get_mscale(self.scaling_factor) * self.attn_factor
+            get_mscale(self.scaling_factor, mscale)
+            / get_mscale(self.scaling_factor, mscale_all_dim)
+            * self.attn_factor
         )  # Get n-d magnitude scaling corrected for interpolation

     def _update_cos_sin_cache(self, dtype, device, seqlen):
@@ -387,7 +402,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
             or self._cos_cached.device != device
             or self._cos_cached.dtype != dtype
         ):
-            if seqlen > self.max_position_embeddings:
+            if seqlen > self.max_position_embeddings or True:
                 inv_freq_extrapolation = _create_inv_freq(
                     self.dim, self.base, self.inv_freq.device
                 )
@@ -400,6 +415,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
                     self.base,
                     self.max_position_embeddings,
                 )
+
                 inv_freq_mask = (
                     1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
                 ) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
@@ -409,9 +425,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
                 )
                 self.inv_freq = inv_freq
-                self.mscale = float(
-                    get_mscale(self.scaling_factor) * self.attn_factor
-                )  # Get n-d magnitude scaling corrected for interpolation

             self._seq_len_cached = seqlen
             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
......
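The yarn changes above reduce to a small formula: the magnitude correction is now the ratio of two `get_mscale` terms, one driven by `mscale` and one by `mscale_all_dim` (the latter is also used in scaling the attention softmax, per the commit message). A standalone sketch of that calculation, with illustrative configuration values rather than the real ones from a model's `rope_scaling` section:

import math


def get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # Same formula as the diff above.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# Illustrative values only; real ones come from the model's rope_scaling config.
scaling_factor = 40.0
mscale = 0.707
mscale_all_dim = 0.707
attn_factor = 1

# This mirrors what YarnPositionRotaryEmbedding.__init__ now computes.
effective_mscale = (
    get_mscale(scaling_factor, mscale)
    / get_mscale(scaling_factor, mscale_all_dim)
    * attn_factor
)
print(effective_mscale)  # 1.0 whenever mscale == mscale_all_dim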
@@ -61,6 +61,10 @@ FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
+    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
+        FlashDeepseekV2ForCausalLM,
+        DeepseekV2Config,
+    )
     from text_generation_server.models.custom_modeling.flash_llama_modeling import (
         FlashLlamaForCausalLM,
     )
@@ -141,6 +145,11 @@ if MAMBA_AVAILABLE:


 class ModelType(enum.Enum):
+    DEEPSEEK_V2 = {
+        "type": "deepseek_v2",
+        "name": "Deepseek V2",
+        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
+    }
     IDEFICS2 = {
         "type": "idefics2",
         "name": "Idefics 2",
@@ -459,7 +468,40 @@ def get_model(
             f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
         )

-    if model_type == MAMBA:
+    if model_type == DEEPSEEK_V2:
+        if FLASH_ATTENTION:
+            head_size = max(
+                config_dict.get("qk_nope_dim", 128)
+                + config_dict.get("qk_rope_dim", 64),
+                config_dict.get("v_head_dim", 128),
+            )
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashDeepseekV2ForCausalLM,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                default_dtype=torch.bfloat16,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+                config_class=DeepseekV2Config,
+                head_size=head_size,
+            )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+    elif model_type == MAMBA:
         return Mamba(
             model_id,
             revision,
......
@@ -839,7 +839,9 @@ class FlashCausalLM(Model):
         default_dtype=torch.float16,
         aliases=None,
         # Used for Santacoder override of config
-        num_kv_heads=None,
+        num_kv_heads: Optional[int] = None,
+        # Deepseek V2 uses different QK and V dims.
+        head_size: Optional[int] = None,
         skip_special_tokens: bool = True,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
@@ -922,7 +924,11 @@ class FlashCausalLM(Model):
             else num_kv_heads
         )
         assert self.num_kv_heads > 0
-        self.head_size = config.hidden_size // config.num_attention_heads
+
+        if head_size is None:
+            self.head_size = config.hidden_size // config.num_attention_heads
+        else:
+            self.head_size = head_size

         self.cuda_graphs = {}
         self.kv_cache = []
......
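The `head_size` override above exists because Deepseek V2's query/key heads (qk_nope_dim + qk_rope_dim, 192 with the defaults used in get_model) are wider than its value heads, so the KV cache is allocated for the larger size and the values are padded during attention, as the commit message notes. A rough sketch of the padding idea, using plain scaled_dot_product_attention and assumed shapes rather than the paged-attention fork:

import torch
import torch.nn.functional as F

# Assumed dimensions for illustration (matching the defaults above).
num_tokens, num_heads = 4, 16
qk_head_dim = 128 + 64  # qk_nope_dim + qk_rope_dim
v_head_dim = 128

query = torch.randn(num_heads, num_tokens, qk_head_dim)
key = torch.randn(num_heads, num_tokens, qk_head_dim)
value = torch.randn(num_heads, num_tokens, v_head_dim)

# Pad V on its last dimension so Q, K and V share head_size=192 and the
# KV cache can be allocated with a single head size.
value_padded = F.pad(value, (0, qk_head_dim - v_head_dim))

attn = F.scaled_dot_product_attention(query, key, value_padded)

# Drop the padding again before the output projection.
attn = attn[..., :v_head_dim]
print(attn.shape)  # torch.Size([16, 4, 128])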
@@ -21,6 +21,13 @@ class WeightsLoader(ABC):
     with the format, etc.
     """

+    @abstractmethod
+    def get_weights(self, weights: "Weights", prefix: str):
+        """
+        Get weights at the given prefix and apply without tensor parallelism.
+        """
+        ...
+
     @abstractmethod
     def get_weights_col_packed(
         self,
@@ -104,6 +111,9 @@ class DefaultWeightsLoader(WeightsLoader):
     and/or concatenation.
     """

+    def get_weights(self, weights: "Weights", prefix: str):
+        return weights.get_tensor(f"{prefix}.weight")
+
     def get_weights_col_packed(
         self,
         weights: "Weights",
@@ -299,6 +309,9 @@ class Weights:
         return tensor

+    def get_weights(self, prefix: str):
+        return self.weights_loader.get_weights(self, prefix)
+
     def get_weights_col_packed_qkv(
         self,
         prefix: str,
......
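The new `get_weights` entry point above is what lets the unsharded projections (`q_a_proj`, `kv_a_proj_with_mqa`) be loaded uniformly whether or not the checkpoint is quantized: each `WeightsLoader` decides which weight object to hand back. A self-contained sketch of the contract, using stand-in classes (`DummyWeights`, `UnquantizedLoader` and the prefix string are illustrations, not TGI code):

import torch


class DummyWeights:
    # Stand-in for text_generation_server's Weights: maps tensor names to tensors.
    def __init__(self, tensors: dict):
        self._tensors = tensors

    def get_tensor(self, name: str) -> torch.Tensor:
        return self._tensors[name]


class UnquantizedLoader:
    # Mirrors DefaultWeightsLoader.get_weights: return the whole, unsharded weight.
    def get_weights(self, weights: DummyWeights, prefix: str) -> torch.Tensor:
        return weights.get_tensor(f"{prefix}.weight")


weights = DummyWeights(
    {"model.layers.0.self_attn.q_a_proj.weight": torch.randn(8, 4)}
)
loader = UnquantizedLoader()
print(loader.get_weights(weights, "model.layers.0.self_attn.q_a_proj").shape)
# A quantized loader (GPTQ, exl2, Marlin) would instead return its own weight
# wrapper from the same call, which is what the diffs above add.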