Commit 5a1cf2f0 authored by huangwb

Merge tag 'v2.0.2' into dev-rocm

parents 24f58bb6 6073ece4
@@ -2,6 +2,7 @@ import math
 import torch
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 BLOCK_SIZE: int = 16
 # Will be set in warmup
@@ -24,7 +25,10 @@ class CacheManager:
         self.repeat_slots = repeat_slots
         element_size = torch.tensor([], dtype=dtype).element_size()
-        x = self.block_size // element_size
+        if IS_XPU_SYSTEM:
+            x = 1
+        else:
+            x = self.block_size // element_size
         self.kv_cache = [
             (
...
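For context on the hunk above: `x` is the packing factor for the key cache's trailing dimension, and forcing it to 1 on XPU falls back to an unpacked layout. A minimal standalone sketch of the arithmetic, assuming the vLLM-style key-cache shape `(num_blocks, num_heads, head_size // x, block_size, x)` that TGI uses; the helper name and example sizes are illustrative, not part of the diff:

```python
import torch

def key_cache_shape(num_blocks, num_heads, head_size, block_size=16,
                    dtype=torch.float16, is_xpu=False):
    # Bytes per element for the KV-cache dtype (2 for fp16/bf16).
    element_size = torch.tensor([], dtype=dtype).element_size()
    # On XPU the packing factor is forced to 1; elsewhere `block_size // element_size`
    # key elements are packed contiguously along the trailing dimension.
    x = 1 if is_xpu else block_size // element_size
    return (num_blocks, num_heads, head_size // x, block_size, x)

print(key_cache_shape(1024, 32, 128))               # CUDA/ROCm, fp16: (1024, 32, 16, 16, 8)
print(key_cache_shape(1024, 32, 128, is_xpu=True))  # XPU: (1024, 32, 128, 16, 1)
```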
@@ -21,8 +21,10 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
-from vllm.model_executor.layers.fused_moe import fused_moe
+if not IS_XPU_SYSTEM:
+    from vllm.model_executor.layers.fused_moe import fused_moe
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     FastLinear,
...
@@ -38,58 +38,6 @@ from text_generation_server.utils.layers import (
 )
-class LlamaConfig(PretrainedConfig):
-    def __init__(
-        self,
-        vocab_size=32000,
-        hidden_size=4096,
-        intermediate_size=11008,
-        num_hidden_layers=32,
-        num_attention_heads=32,
-        num_key_value_heads=None,
-        hidden_act="silu",
-        max_position_embeddings=2048,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=1,
-        eos_token_id=2,
-        pretraining_tp=1,
-        tie_word_embeddings=False,
-        rope_scaling=None,
-        rope_theta=10000.0,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.pretraining_tp = pretraining_tp
-        self.use_cache = use_cache
-        self.rope_scaling = rope_scaling
-        self.rope_theta = rope_theta
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
@@ -101,6 +49,13 @@ def load_attention(config, prefix, weights):
                 weights=weights,
                 bias=False,
             )
+        elif config.model_type == "phi3":
+            return TensorParallelColumnLinear.load_qkv(
+                config,
+                prefix=f"{prefix}.qkv_proj",
+                weights=weights,
+                bias=False,
+            )
         else:
             return TensorParallelColumnLinear.load_multi(
                 config,
@@ -257,13 +212,21 @@ class LlamaMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
-            weights=weights,
-            dim=0,
-            bias=False,
-        )
+        if config.model_type == "phi3":
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+                config,
+                prefix=f"{prefix}.gate_up_proj",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+                weights=weights,
+                dim=0,
+                bias=False,
+            )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
...
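Background for the phi3 branches above: Phi-3 checkpoints ship the attention projections as a single fused `qkv_proj` tensor and the MLP gate/up projections as a single fused `gate_up_proj` tensor, so the loader reads one weight instead of concatenating separate ones. A toy sketch of how a fused gate/up weight is consumed, assuming the gate half is stored before the up half; the helper name and sizes are illustrative, not taken from the repo:

```python
import torch
import torch.nn.functional as F

def fused_gate_up_forward(hidden: torch.Tensor, gate_up_weight: torch.Tensor) -> torch.Tensor:
    # gate_up_weight: [2 * intermediate_size, hidden_size], gate rows first, then up rows.
    projected = hidden @ gate_up_weight.T          # [..., 2 * intermediate_size]
    gate, up = projected.chunk(2, dim=-1)          # split the fused output in half
    return F.silu(gate) * up                       # SwiGLU-style gating

hidden = torch.randn(1, 3072)
fused_w = torch.randn(2 * 8192, 3072)
print(fused_gate_up_forward(hidden, fused_w).shape)  # torch.Size([1, 8192])
```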
@@ -409,23 +409,29 @@ class MistralModel(torch.nn.Module):
 class FlashMistralForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, name=None):
+        if name is None:
+            name = "model"
         super().__init__()
         self.embed_tokens = TensorParallelEmbedding(
             prefix=(
-                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+                f"{name}.embed_tokens"
+                if not prefix
+                else f"{prefix}.{name}.embed_tokens"
             ),
             weights=weights,
         )
         self.model = MistralModel(
-            prefix="model" if not prefix else f"{prefix}.model",
+            prefix=name if not prefix else f"{prefix}.{name}",
             config=config,
             weights=weights,
         )
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+            # TODO dirty hack for idefics2.
+            prefix=(
+                "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
+            ),
             weights=weights,
         )
         self.max_past = config.sliding_window
...
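To make the new `name` parameter concrete, here are the weight prefixes it produces for a couple of call patterns; the VLM case is only meant to illustrate how a wrapper such as Idefics2 can reuse this class, not a claim about its exact call site:

```python
def embed_tokens_prefix(prefix: str, name: str = "model") -> str:
    # Mirrors the expression in FlashMistralForCausalLM.__init__ above.
    return f"{name}.embed_tokens" if not prefix else f"{prefix}.{name}.embed_tokens"

print(embed_tokens_prefix(""))                          # model.embed_tokens (standalone Mistral)
print(embed_tokens_prefix("model", name="text_model"))  # model.text_model.embed_tokens (VLM-style nesting)
```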
...@@ -24,7 +24,10 @@ import torch.distributed ...@@ -24,7 +24,10 @@ import torch.distributed
import numpy as np import numpy as np
from torch import nn from torch import nn
from vllm.model_executor.layers.fused_moe import fused_moe from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
if not IS_XPU_SYSTEM:
from vllm.model_executor.layers.fused_moe import fused_moe
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
......
@@ -23,6 +23,10 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.image_processing_utils import select_best_resolution
+from text_generation_server.models.custom_modeling.vlm import (
+    load_text_model,
+    load_vision_model,
+)
 from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
@@ -105,36 +109,6 @@ class LlavaNextMultiModalProjector(nn.Module):
         return hidden_states
-def load_vision_model(prefix, config, weights):
-    if config.model_type == "clip_vision_model":
-        from text_generation_server.models.custom_modeling.clip import (
-            CLIPVisionTransformer,
-        )
-        return CLIPVisionTransformer(
-            prefix=f"{prefix}.vision_model", config=config, weights=weights
-        )
-    else:
-        raise RuntimeError(f"Unsupported model type {config.model_type}")
-def load_text_model(prefix, config, weights):
-    if config.model_type == "llama":
-        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
-            FlashLlamaForCausalLM,
-        )
-        return FlashLlamaForCausalLM(prefix, config, weights)
-    elif config.model_type == "mistral":
-        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
-            FlashMistralForCausalLM,
-        )
-        return FlashMistralForCausalLM(prefix, config, weights)
-    else:
-        raise RuntimeError(f"Unsupported model type {config.model_type}")
 class LlavaNextForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
@@ -180,7 +154,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
         """In place merges in vision_embeddings with inputs_embeds."""
         mask = input_ids == self.config.image_token_index
         # Let's pray we have enabled enough slots !
-        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        try:
+            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
+        except Exception as e:
+            raise RuntimeError(
+                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
+            )
         return inputs_embeds
     def forward(
@@ -196,6 +175,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
+        # Unused for this model
+        pixel_attention_mask=None,
         image_sizes: Optional[torch.LongTensor] = None,
     ):
         inputs_embeds = self.language_model.embed_tokens(input_ids)
...
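The in-place merge above only succeeds when the number of image-token positions in `input_ids` matches the number of image feature vectors, which is exactly the failure the new `try/except` surfaces. A self-contained toy illustration with tiny sizes, unrelated to the model's real dimensions:

```python
import torch

image_token_index = 32000                                # illustrative id
input_ids = torch.tensor([1, 32000, 32000, 5, 32000])    # three image-token slots
inputs_embeds = torch.zeros(5, 4)                        # [seq_len, hidden]
image_features = torch.ones(3, 4)                        # must supply exactly three vectors

mask = input_ids == image_token_index
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
print(inputs_embeds[mask].sum())                         # tensor(12.) -> all three slots filled

# With only two feature vectors the assignment raises a shape-mismatch error,
# which the diff now wraps in a more actionable RuntimeError.
```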
def load_text_model(prefix, config, weights, name=None):
if config.model_type == "llama":
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
return FlashLlamaForCausalLM(prefix, config, weights)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
return FlashMistralForCausalLM(prefix, config, weights, name=name)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
def load_vision_model(prefix, config, weights):
if config.model_type == "clip_vision_model":
from text_generation_server.models.custom_modeling.clip import (
CLIPVisionTransformer,
)
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")
@@ -33,6 +33,11 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
 from text_generation_server.utils.dist import MEMORY_FRACTION
 tracer = trace.get_tracer(__name__)
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
 @dataclass
@@ -754,7 +759,10 @@ class FlashCausalLM(Model):
     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
-        torch.cuda.empty_cache()
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.empty_cache()
+        elif IS_XPU_SYSTEM:
+            torch.xpu.empty_cache()
         try:
             cache_manager = set_cache_manager(
                 batch.blocks,
@@ -774,7 +782,10 @@ class FlashCausalLM(Model):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
-        torch.cuda.synchronize(self.device)
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.synchronize(self.device)
+        elif IS_XPU_SYSTEM:
+            torch.xpu.synchronize(self.device)
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
@@ -782,12 +793,20 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
-        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
-        free_memory = max(
-            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
-        )
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            total_free_memory, _ = torch.cuda.mem_get_info(self.device)
+            total_gpu_memory = torch.cuda.get_device_properties(
+                self.device
+            ).total_memory
+            free_memory = max(
+                0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
+            )
+        elif IS_XPU_SYSTEM:
+            total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
+            free_memory = int(total_gpu_memory * 0.5)
+        else:
+            raise NotImplementedError("FlashModel is only available on GPU")
         num_blocks = (
             # Leave 5% for some wiggle room
@@ -818,6 +837,8 @@ class FlashCausalLM(Model):
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")
+        else:
+            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
         return int(num_blocks * BLOCK_SIZE)
...
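A small worked example of the block-count arithmetic in the last hunk, with made-up sizes; the real values come from the model config, the measured free memory, and MEMORY_FRACTION, and dropping the warmup batch's already-allocated blocks is a simplification of mine:

```python
BLOCK_SIZE = 16                       # tokens per KV-cache block
num_layers, num_kv_heads, head_size = 32, 8, 128
dtype_size = 2                        # bytes per element for fp16

# Elements per block per layer, times 2 for keys+values, times dtype size.
cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
total_cache_size = num_layers * cache_block_size * 2 * dtype_size   # 2 MiB per block here

free_memory = 20 * 1024**3            # pretend 20 GiB remain after loading weights
num_blocks = int(free_memory * 0.95 // total_cache_size)            # keep ~5% wiggle room
print(num_blocks, num_blocks * BLOCK_SIZE)                          # 9728 blocks -> 155648 cacheable tokens
```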
@@ -2,14 +2,13 @@ import torch
 import torch.distributed
 from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer, GenerationConfig
 from transformers.models.llama import LlamaTokenizer
 from typing import Optional
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
-    LlamaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -19,6 +18,8 @@ from text_generation_server.utils import (
 tracer = trace.get_tracer(__name__)
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 class FlashLlama(FlashCausalLM):
     def __init__(
@@ -34,6 +35,9 @@ class FlashLlama(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
@@ -53,8 +57,17 @@ class FlashLlama(FlashCausalLM):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id, revision=revision, trust_remote_code=trust_remote_code
+            )
+            if isinstance(generation_config.eos_token_id, (list, set)):
+                # TODO Huge hack
+                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
+        except Exception:
+            pass
-        config = LlamaConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
...
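Why the GenerationConfig shim above exists: some checkpoints (Llama 3, for example) declare several end-of-sequence ids in generation_config.json while the tokenizer exposes only one eos_token_id, so the set stashed on the tokenizer lets stopping logic match any of them. A minimal sketch of that membership check; the stopping-criteria wiring itself is not part of this diff:

```python
# e.g. generation_config.json: "eos_token_id": [128001, 128009]
eos_token_ids = {128001, 128009}

def is_stop_token(token_id: int) -> bool:
    # Any declared EOS id ends generation, not only tokenizer.eos_token_id.
    return token_id in eos_token_ids

print(is_stop_token(128009), is_stop_token(42))  # True False
```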
@@ -33,8 +33,9 @@ tracer = trace.get_tracer(__name__)
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
-MEM_POOL = torch.cuda.graph_pool_handle()
+MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
@@ -316,6 +317,9 @@ class BaseFlashMistral(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")
...
@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 tracer = trace.get_tracer(__name__)
@@ -32,6 +33,9 @@ class FlashNeoXSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
...
@@ -4,7 +4,7 @@ import torch
 import torch.distributed
 from opentelemetry import trace
-from transformers.models.qwen2 import Qwen2Tokenizer
+from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
 from text_generation_server.models.cache_manager import BLOCK_SIZE
@@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import (
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2ForCausalLM,
 )
-from transformers.models.qwen2 import Qwen2Config
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -42,7 +41,7 @@ class FlashQwen2(BaseFlashMistral):
         else:
             raise NotImplementedError("FlashQwen2 is only available on GPU")
-        tokenizer = Qwen2Tokenizer.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
@@ -50,7 +49,7 @@ class FlashQwen2(BaseFlashMistral):
             trust_remote_code=trust_remote_code,
         )
-        config = Qwen2Config.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
...
@@ -15,6 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 tracer = trace.get_tracer(__name__)
@@ -33,6 +34,9 @@ class FlashRWSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
...
@@ -18,6 +18,8 @@ from text_generation_server.utils import (
     Weights,
 )
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 tracer = trace.get_tracer(__name__)
@@ -35,6 +37,9 @@ class FlashSantacoderSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
...
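All of the device-selection branches above key off flags imported from text_generation_server.utils.import_utils, which is not itself part of this diff. A plausible reconstruction of how such flags are commonly derived, offered purely as an assumption about that module rather than its actual contents:

```python
import torch

# Hypothetical sketch: the real import_utils module may differ.
IS_ROCM_SYSTEM = torch.version.hip is not None           # ROCm builds expose torch.version.hip
IS_CUDA_SYSTEM = torch.version.cuda is not None          # CUDA builds expose torch.version.cuda
IS_XPU_SYSTEM = hasattr(torch, "xpu") and torch.xpu.is_available()  # requires intel-extension-for-pytorch
```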
 import torch
 import os
-MEM_POOL = torch.cuda.graph_pool_handle()
+MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
@@ -11,4 +11,7 @@ if cuda_graphs is not None:
         raise RuntimeError(
             f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
         )
+else:
+    cuda_graphs = None
 CUDA_GRAPHS = cuda_graphs
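For reference, CUDA_GRAPHS takes a comma-separated list of decode batch sizes. The hidden lines of the hunk are not shown here, so the parse below is a reconstruction consistent with the error message above, not the file's exact code:

```python
import os

os.environ["CUDA_GRAPHS"] = "1,2,4,8"                        # e.g. capture decode graphs for these batch sizes
raw = os.getenv("CUDA_GRAPHS")
batch_sizes = [int(x) for x in raw.split(",")] if raw else None
print(batch_sizes)                                           # [1, 2, 4, 8]
```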
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class Idefics2(VlmCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
size={"longest_edge": 448, "shortest_edge": 378},
)
super().__init__(
model_cls=Idefics2ForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.text_model.model.layers),
model.text_model.model.num_key_value_heads,
model.text_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)
 import torch
-from typing import Optional
+from typing import Optional, Tuple
 from transformers import (
     AutoProcessor,
@@ -34,3 +34,13 @@ class LlavaNext(VlmCausalLM):
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.language_model.model.layers),
+            model.language_model.model.num_key_value_heads,
+            model.language_model.model.head_size,
+        )
+    def max_past(self) -> Optional[int]:
+        return getattr(self.model.language_model, "max_past", None)
@@ -474,6 +474,8 @@ class Mamba(Model):
                     self.cuda_graph_warmup(bs)
             except Exception:
                 logger.exception(f"Decode cuda graph warmup failed")
+        else:
+            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
         return None
...
@@ -27,7 +27,14 @@ class Model(ABC):
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
+        # all_special_ids is not set correctly if the rust tokenizer is unpacked
+        # TODO report this to transformers.
+        other_special_ids = {
+            id for id, token in tokenizer.added_tokens_decoder.items() if token.special
+        }
         self.all_special_ids = set(tokenizer.all_special_ids)
+        self.all_special_ids.update(other_special_ids)
         self.requires_padding = requires_padding
         self.dtype = dtype
         self.device = device
...
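A quick illustration of the transformers API the hunk above relies on: added_tokens_decoder maps token ids to AddedToken objects whose .special flag marks tokens that should be treated as special even when all_special_ids misses them. The model name below is only an example:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")                       # any fast tokenizer works
tok.add_special_tokens({"additional_special_tokens": ["<|endofturn|>"]})

# added_tokens_decoder: Dict[int, AddedToken]; .special marks special added tokens.
extra_special = {i for i, t in tok.added_tokens_decoder.items() if t.special}
all_special_ids = set(tok.all_special_ids) | extra_special        # the union the diff performs
print(sorted(all_special_ids))
```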