Commit 5a1cf2f0 authored by huangwb's avatar huangwb
Browse files

Merge tag 'v2.0.2' into dev-rocm

parents 24f58bb6 6073ece4
......@@ -2,6 +2,7 @@ import math
import torch
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
BLOCK_SIZE: int = 16
# Will be set in warmup
......@@ -24,7 +25,10 @@ class CacheManager:
self.repeat_slots = repeat_slots
element_size = torch.tensor([], dtype=dtype).element_size()
x = self.block_size // element_size
if IS_XPU_SYSTEM:
x = 1
else:
x = self.block_size // element_size
self.kv_cache = [
(
......
......@@ -21,8 +21,10 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from loguru import logger
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
from vllm.model_executor.layers.fused_moe import fused_moe
if not IS_XPU_SYSTEM:
from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
FastLinear,
......
......@@ -38,58 +38,6 @@ from text_generation_server.utils.layers import (
)
class LlamaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
rope_theta=10000.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
......@@ -101,6 +49,13 @@ def load_attention(config, prefix, weights):
weights=weights,
bias=False,
)
elif config.model_type == "phi3":
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv_proj",
weights=weights,
bias=False,
)
else:
return TensorParallelColumnLinear.load_multi(
config,
......@@ -257,13 +212,21 @@ class LlamaMLP(nn.Module):
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
if config.model_type == "phi3":
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
config,
prefix=f"{prefix}.gate_up_proj",
weights=weights,
bias=False,
)
else:
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
......
......@@ -409,23 +409,29 @@ class MistralModel(torch.nn.Module):
class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, name=None):
if name is None:
name = "model"
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
f"{name}.embed_tokens"
if not prefix
else f"{prefix}.{name}.embed_tokens"
),
weights=weights,
)
self.model = MistralModel(
prefix="model" if not prefix else f"{prefix}.model",
prefix=name if not prefix else f"{prefix}.{name}",
config=config,
weights=weights,
)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
# TODO dirty hack for idefics2.
prefix=(
"lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
),
weights=weights,
)
self.max_past = config.sliding_window
......
......@@ -24,7 +24,10 @@ import torch.distributed
import numpy as np
from torch import nn
from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
if not IS_XPU_SYSTEM:
from vllm.model_executor.layers.fused_moe import fused_moe
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
......
......@@ -23,6 +23,10 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
......@@ -105,36 +109,6 @@ class LlavaNextMultiModalProjector(nn.Module):
return hidden_states
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_text_model(prefix, config, weights):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
class LlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
......@@ -180,7 +154,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds
def forward(
......@@ -196,6 +175,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.language_model.embed_tokens(input_ids)
......
def load_text_model(prefix, config, weights, name=None):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights, name=name)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
......@@ -33,6 +33,11 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
from text_generation_server.utils.dist import MEMORY_FRACTION
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
IS_XPU_SYSTEM,
)
@dataclass
......@@ -754,7 +759,10 @@ class FlashCausalLM(Model):
def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive
torch.cuda.empty_cache()
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
torch.cuda.empty_cache()
elif IS_XPU_SYSTEM:
torch.xpu.empty_cache()
try:
cache_manager = set_cache_manager(
batch.blocks,
......@@ -774,7 +782,10 @@ class FlashCausalLM(Model):
f"You need to decrease `--max-batch-prefill-tokens`"
) from e
torch.cuda.synchronize(self.device)
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
torch.cuda.synchronize(self.device)
elif IS_XPU_SYSTEM:
torch.xpu.synchronize(self.device)
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
......@@ -782,12 +793,20 @@ class FlashCausalLM(Model):
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
total_free_memory, _ = torch.cuda.mem_get_info(self.device)
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
total_free_memory, _ = torch.cuda.mem_get_info(self.device)
total_gpu_memory = torch.cuda.get_device_properties(
self.device
).total_memory
free_memory = max(
0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
)
free_memory = max(
0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
)
elif IS_XPU_SYSTEM:
total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
free_memory = int(total_gpu_memory * 0.5)
else:
raise NotImplementedError("FlashModel is only available on GPU")
num_blocks = (
# Leave 5% for some wiggle room
......@@ -818,6 +837,8 @@ class FlashCausalLM(Model):
self.cuda_graph_warmup(bs, max_s, max_bt)
except torch.cuda.OutOfMemoryError:
logger.exception(f"Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
return int(num_blocks * BLOCK_SIZE)
......
......@@ -2,14 +2,13 @@ import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.llama import LlamaTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
LlamaConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
......@@ -19,6 +18,8 @@ from text_generation_server.utils import (
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
class FlashLlama(FlashCausalLM):
def __init__(
......@@ -34,6 +35,9 @@ class FlashLlama(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashLlama is only available on GPU")
......@@ -53,8 +57,17 @@ class FlashLlama(FlashCausalLM):
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = LlamaConfig.from_pretrained(
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
......
......@@ -33,8 +33,9 @@ tracer = trace.get_tracer(__name__)
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
MEM_POOL = torch.cuda.graph_pool_handle()
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
......@@ -316,6 +317,9 @@ class BaseFlashMistral(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashMistral is only available on GPU")
......
......@@ -14,6 +14,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
tracer = trace.get_tracer(__name__)
......@@ -32,6 +33,9 @@ class FlashNeoXSharded(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashNeoX is only available on GPU")
......
......@@ -4,7 +4,7 @@ import torch
import torch.distributed
from opentelemetry import trace
from transformers.models.qwen2 import Qwen2Tokenizer
from transformers import AutoTokenizer, AutoConfig
from typing import Optional
from text_generation_server.models.cache_manager import BLOCK_SIZE
......@@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import (
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from transformers.models.qwen2 import Qwen2Config
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
......@@ -42,7 +41,7 @@ class FlashQwen2(BaseFlashMistral):
else:
raise NotImplementedError("FlashQwen2 is only available on GPU")
tokenizer = Qwen2Tokenizer.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
......@@ -50,7 +49,7 @@ class FlashQwen2(BaseFlashMistral):
trust_remote_code=trust_remote_code,
)
config = Qwen2Config.from_pretrained(
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
......
......@@ -15,6 +15,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
tracer = trace.get_tracer(__name__)
......@@ -33,6 +34,9 @@ class FlashRWSharded(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashRW is only available on GPU")
......
......@@ -18,6 +18,8 @@ from text_generation_server.utils import (
Weights,
)
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
tracer = trace.get_tracer(__name__)
......@@ -35,6 +37,9 @@ class FlashSantacoderSharded(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
......
import torch
import os
MEM_POOL = torch.cuda.graph_pool_handle()
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
......@@ -11,4 +11,7 @@ if cuda_graphs is not None:
raise RuntimeError(
f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
)
else:
cuda_graphs = None
CUDA_GRAPHS = cuda_graphs
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class Idefics2(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
size={"longest_edge": 448, "shortest_edge": 378},
)
super().__init__(
model_cls=Idefics2ForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)
import torch
from typing import Optional
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
......@@ -34,3 +34,13 @@ class LlavaNext(VlmCausalLM):
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)
......@@ -474,6 +474,8 @@ class Mamba(Model):
self.cuda_graph_warmup(bs)
except Exception:
logger.exception(f"Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
return None
......
......@@ -27,7 +27,14 @@ class Model(ABC):
):
self.model = model.eval()
self.tokenizer = tokenizer
# all_special_ids is not set correctly if the rust tokenizer is unpacked
# TODO report this to transformers.
other_special_ids = {
id for id, token in tokenizer.added_tokens_decoder.items() if token.special
}
self.all_special_ids = set(tokenizer.all_special_ids)
self.all_special_ids.update(other_special_ids)
self.requires_padding = requires_padding
self.dtype = dtype
self.device = device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment