Commit 5a1cf2f0 authored by huangwb

Merge tag 'v2.0.2' into dev-rocm

parents 24f58bb6 6073ece4
@@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    return height // patch_size, width // patch_size


+def image_text_replacement(image_input, config, image_id) -> str:
+    if config.model_type == "idefics2":
+        # TODO technically depends on image splitting which is not implemented.
+        num_features = 320
+        return (
+            "<fake_token_around_image>"
+            + "<image>" * num_features
+            + "<fake_token_around_image>"
+        )
+    elif config.model_type == "llava_next":
+        height, width = image_input["image_sizes"][image_id]
+        num_features = get_number_of_features(height, width, config)
+        from loguru import logger
+
+        logger.info(f"Found {num_features} in image of resolution {height}x{width}")
+        return "<image>" * num_features
+    else:
+        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
+def get_unpadded_features(
+    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
+) -> Tuple[int, int]:
+    current_height = npatches * num_patch_height
+    current_width = npatches * num_patch_width
+
+    aspect_ratio: float = width / height
+    current_aspect_ratio: float = current_width / current_height
+    if aspect_ratio > current_aspect_ratio:
+        new_height = (height * current_width) // width
+        current_height = new_height
+    else:
+        new_width = (width * current_height) // height
+        current_width = new_width
+
+    unpadded_features = current_height * current_width
+    newline_features = current_height
+    return (unpadded_features, newline_features)
+
+
def get_number_of_features(height: int, width: int, config) -> int:
    # From config
    # Hardcoded for CLIP for now
@@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
        image_grid_pinpoints,
        image_size,
    )
-    height_of_patch = math.ceil(height / width * npatches)
-
-    unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
-    # They are only added after width
-    newline_features = height_of_patch * num_patch_width
+    unpadded_features, newline_features = get_unpadded_features(
+        height, width, npatches, num_patch_height, num_patch_width
+    )
    # The base patch covers the entire image
    base_features = npatches**2
    return unpadded_features + newline_features + base_features
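
Note on the change above: rather than deriving a patch-grid height from the aspect ratio with math.ceil, the llava_next token count now resizes the unpadded feature grid to the image's aspect ratio and adds one newline feature per row. A self-contained sketch of that arithmetic, assuming a CLIP-style tower with 336px input, 14px patches and a 2x2 anyres grid (these values come from the model config, not from this hunk):

# Example (not part of the diff): feature-count arithmetic for llava_next.
# Assumed values: image_size=336, patch_size=14, anyres grid of 2x2 tiles.
npatches = 336 // 14          # 24 patches per side
num_patch_height, num_patch_width = 2, 2

def unpadded_features(height: int, width: int) -> tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width
    aspect_ratio = width / height
    current_aspect_ratio = current_width / current_height
    if aspect_ratio > current_aspect_ratio:
        current_height = (height * current_width) // width
    else:
        current_width = (width * current_height) // height
    # one "newline" feature is appended after every row of the unpadded grid
    return current_height * current_width, current_height

unpadded, newlines = unpadded_features(640, 640)
base = npatches**2            # the base patch always covers the full image
print(unpadded + newlines + base)  # 2928 for a 640x640 image (matches the old assert comment removed below)
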
@@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
    return image


-# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
-# assert get_number_of_features(640, 640) == 2928


class VlmCausalLMBatch(FlashMistralBatch):
    pixel_values: Optional[List[torch.Tensor]]
+    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]

    @classmethod
@@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
    def concatenate(cls, batches):
        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
        batch.pixel_values = None
+        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch
@@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
    def filter(self, request_ids: List[int]):
        batch = super().filter(request_ids)
        batch.pixel_values = None
+        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch
@@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
        for r in requests:
            chunks = split(r.inputs)
            full_text = ""
+            image_id = 0
            for chunk in chunks:
                if chunk["type"] == "text":
                    full_text += chunk["content"]
@@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
                            "Cannot process input image not starting with data:"
                        )
                    image_input = processor.image_processor(image, return_tensors="pt")
-                    height, width = image_input["image_sizes"][0]
-                    num_features = get_number_of_features(height, width, config)
-                    full_text += "<image>" * num_features
+                    full_text += image_text_replacement(image_input, config, image_id)
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
@@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        if image_inputs:
-            image_inputs = {
+            image_input = image_inputs[0]
+            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
-                "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
            }
+            if "pixel_attention_mask" in image_input:
+                new_image_inputs["pixel_attention_mask"] = torch.cat(
+                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
+                )
+            if "image_sizes" in image_input:
+                new_image_inputs["image_sizes"] = torch.cat(
+                    [img["image_sizes"] for img in image_inputs], dim=0
+                )
+            image_inputs = new_image_inputs
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs
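
A note on the merge above: only pixel_values is guaranteed to be present; idefics2's processor also returns a pixel_attention_mask, while llava_next returns image_sizes, so each extra key is concatenated only when the first image input carries it. A minimal standalone sketch of the same pattern (tensor shapes are made up):

# Example (not part of the diff): merging per-image processor outputs with optional keys.
import torch

image_inputs = [
    {"pixel_values": torch.rand(1, 3, 336, 336), "image_sizes": torch.tensor([[640, 640]])},
    {"pixel_values": torch.rand(1, 3, 336, 336), "image_sizes": torch.tensor([[889, 1024]])},
]

merged = {"pixel_values": torch.cat([img["pixel_values"] for img in image_inputs], dim=0)}
# Optional keys are only concatenated when the processor produced them.
for key in ("pixel_attention_mask", "image_sizes"):
    if key in image_inputs[0]:
        merged[key] = torch.cat([img[key] for img in image_inputs], dim=0)

print({k: v.shape for k, v in merged.items()})
# {'pixel_values': torch.Size([2, 3, 336, 336]), 'image_sizes': torch.Size([2, 2])}
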
@@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
-            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            if "pixel_attention_mask" in image_inputs:
+                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
+                    device=device
+                )
+            else:
+                batch.pixel_attention_mask = None
+            if "image_sizes" in image_inputs:
+                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            else:
+                batch.image_sizes = None
        else:
            batch.pixel_values = None
+            batch.pixel_attention_mask = None
            batch.image_sizes = None
        return batch
@@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return VlmCausalLMBatch

-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.language_model.model.layers),
-            model.language_model.model.num_key_value_heads,
-            model.language_model.model.head_size,
-        )
-
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.language_model, "max_past", None)
-
    def forward(
        self, batch: VlmCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
@@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                pixel_values=batch.pixel_values,
+                pixel_attention_mask=batch.pixel_attention_mask,
                image_sizes=batch.image_sizes,
            )
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            if batch.pixel_values is not None:
                batch.pixel_values = None
+            if batch.pixel_attention_mask is not None:
+                batch.pixel_attention_mask = None
            if batch.image_sizes is not None:
                batch.image_sizes = None
            return logits, speculative_logits
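
On the CUDA-graph change in this forward pass: instead of hard-coded padding rules, the smallest captured graph whose batch size can hold the current batch is selected, falling back to the regular forward when none fits. A toy sketch of that selection (bucket sizes here are invented):

# Example (not part of the diff): bucket selection for captured CUDA graphs.
cuda_graphs = {1: "graph_bs1", 2: "graph_bs2", 4: "graph_bs4", 8: "graph_bs8"}

def pick_graph(bs: int):
    # Smallest captured bucket that can hold the current batch, else eager mode.
    candidates = sorted(k for k in cuda_graphs if k >= bs)
    return cuda_graphs[candidates[0]] if candidates else None

print(pick_graph(3))   # graph_bs4
print(pick_graph(8))   # graph_bs8
print(pick_graph(9))   # None -> run the regular (non-graph) forward
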
......
@@ -2,6 +2,7 @@ import asyncio
import os
import torch
import time
+import signal

from grpc import aio
from loguru import logger
@@ -19,6 +20,21 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


+class SignalHandler:
+    KEEP_PROCESSING = True
+
+    def __init__(self):
+        signal.signal(signal.SIGINT, self.exit_gracefully)
+        signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        print(f"Exiting gracefully: Signal {signum}")
+        self.KEEP_PROCESSING = False
+
+
+signal_handler = SignalHandler()
+
+
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(
        self,
@@ -231,11 +247,8 @@ def serve(
        logger.info("Server started at {}".format(local_url))

-        try:
-            await server.wait_for_termination()
-        except KeyboardInterrupt:
-            logger.info("Signal received. Shutting down")
-            await server.stop(0)
+        while signal_handler.KEEP_PROCESSING:
+            await asyncio.sleep(0.5)

    asyncio.run(
        serve_inner(
......
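
The server change above replaces await server.wait_for_termination() with a polling loop guarded by a signal handler, so SIGTERM as well as SIGINT can stop the serving loop cleanly. A runnable, stripped-down sketch of the same shutdown pattern, with a print standing in for the real gRPC server:

# Example (not part of the diff): graceful-shutdown loop driven by signals.
import asyncio
import signal

class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False

signal_handler = SignalHandler()

async def serve():
    print("Server started")  # stand-in for starting the real gRPC server
    while signal_handler.KEEP_PROCESSING:
        # Wake up twice a second to check whether a signal asked us to stop.
        await asyncio.sleep(0.5)
    print("Shutting down")

if __name__ == "__main__":
    asyncio.run(serve())
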
@@ -57,7 +57,14 @@ def initialize_torch_distributed():
        options.is_high_priority_stream = True
        options._timeout = timedelta(seconds=60)
    else:
-        backend = "gloo"
+        try:
+            import oneccl_bindings_for_pytorch
+
+            backend = "ccl"
+            if os.getenv("CCL_WORKER_COUNT", None) is None:
+                os.environ["CCL_WORKER_COUNT"] = str(1)
+        except ImportError:
+            backend = "gloo"
        options = None

    if WORLD_SIZE == 1:
......
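
The distributed-init change prefers Intel's oneCCL backend when its PyTorch bindings are importable and only falls back to gloo otherwise. A hedged sketch of how such a backend choice plugs into torch.distributed (single-process rendezvous, illustrative env values):

# Example (not part of the diff): backend selection feeding process-group init.
import os
import torch.distributed as dist

def pick_cpu_backend() -> str:
    """Prefer oneCCL when its PyTorch bindings are importable, else gloo."""
    try:
        import oneccl_bindings_for_pytorch  # noqa: F401
        os.environ.setdefault("CCL_WORKER_COUNT", "1")
        return "ccl"
    except ImportError:
        return "gloo"

if __name__ == "__main__":
    # Illustrative single-process rendezvous; a real launch sets these per rank.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend=pick_cpu_backend(), rank=0, world_size=1)
    print(f"initialized with backend={dist.get_backend()}")
    dist.destroy_process_group()
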
@@ -2,69 +2,81 @@ import os
import torch

from loguru import logger
+import math

-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")
+HAS_FLASH_ATTN = True
+HAS_FLASH_ATTN_V2_CUDA = False
+HAS_FLASH_ATTN_V2_ROCM = False

-if not torch.cuda.is_available():
-    raise ImportError("CUDA is not available")
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex

-major, minor = torch.cuda.get_device_capability()
-is_sm75 = major == 7 and minor == 5
-is_sm8x = major == 8 and minor >= 0
-is_sm90 = major == 9 and minor == 0
-
-HAS_FLASH_ATTN = False
-HAS_FLASH_ATTN_V2_CUDA = False
-HAS_FLASH_ATTN_V2_ROCM = False
-try:
-    try:
-        import flash_attn_2_cuda
-    except ImportError:
-        architecture_suffix = ""
-        if IS_CUDA_SYSTEM:
-            architecture_suffix = "-cuda"
-        elif IS_ROCM_SYSTEM:
-            architecture_suffix = "-rocm"
-        raise ImportError(
-            "Flash Attention V2 is not installed.\n"
-            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-            f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
-        )
-    if not (is_sm8x or is_sm90) and IS_CUDA_SYSTEM:
-        raise ImportError(
-            f"GPU with CUDA capability {major} {minor} is not supported for "
-            "Flash Attention V2"
-        )
-    HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
-    HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
-except ImportError as e:
-    try:
-        import flash_attn_cuda
-    except ImportError:
-        raise ImportError(
-            "Flash Attention is not installed.\n"
-            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-            "or install flash attention with `cd server && make install install-flash-attention`"
-        ) from e
-
-    if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
-        raise ImportError(
-            f"GPU with CUDA capability {major} {minor} is not supported"
-        ) from e
-    elif IS_ROCM_SYSTEM:
-        for idx in range(torch.cuda.device_count()):
-            if "MI210" not in torch.cuda.get_device_name(
-                idx
-            ) and "MI250" not in torch.cuda.get_device_name(idx):
-                raise ImportError(
-                    f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
-                )
-
-    logger.warning(f"Unable to use Flash Attention V2: {e}")
-    HAS_FLASH_ATTN = True
+if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+    if not torch.cuda.is_available():
+        raise ImportError("CUDA is not available")
+
+    major, minor = torch.cuda.get_device_capability()
+    is_sm75 = major == 7 and minor == 5
+    is_sm8x = major == 8 and minor >= 0
+    is_sm90 = major == 9 and minor == 0
+
+    HAS_FLASH_ATTN = False
+    HAS_FLASH_ATTN_V2_CUDA = False
+    HAS_FLASH_ATTN_V2_ROCM = False
+    try:
+        try:
+            import flash_attn_2_cuda
+        except ImportError:
+            architecture_suffix = ""
+            if IS_CUDA_SYSTEM:
+                architecture_suffix = "-cuda"
+            elif IS_ROCM_SYSTEM:
+                architecture_suffix = "-rocm"
+            raise ImportError(
+                "Flash Attention V2 is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
+            )
+        if not (is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported for "
+                "Flash Attention V2"
+            )
+        HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
+        HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
+    except ImportError as e:
+        try:
+            import flash_attn_cuda
+        except ImportError:
+            raise ImportError(
+                "Flash Attention is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                "or install flash attention with `cd server && make install install-flash-attention`"
+            ) from e
+
+        if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported"
+            ) from e
+        elif IS_ROCM_SYSTEM:
+            for idx in range(torch.cuda.device_count()):
+                if "MI210" not in torch.cuda.get_device_name(
+                    idx
+                ) and "MI250" not in torch.cuda.get_device_name(idx):
+                    raise ImportError(
+                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                    )

+        logger.warning(f"Unable to use Flash Attention V2: {e}")
+        HAS_FLASH_ATTN = True
@@ -80,6 +92,28 @@ def attention(
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")

+    if IS_XPU_SYSTEM:
+        if window_size_left != -1:
+            raise ValueError(
+                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+            )
+        return ipex.llm.functional.varlen_attention(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+
    if HAS_FLASH_ATTN_V2_CUDA:
        return flash_attn_2_cuda.varlen_fwd(
            q,
......
import torch


+def is_xpu_available():
+    try:
+        import intel_extension_for_pytorch
+    except ImportError:
+        return False
+
+    return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
+IS_XPU_SYSTEM = is_xpu_available()
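
These flags are evaluated once at import time and the rest of the server branches on them. A small illustrative sketch of device selection built on the same flags (the get_device helper is hypothetical, and IS_XPU_SYSTEM is simplified here by skipping the intel_extension_for_pytorch import check):

# Example (not part of the diff): picking a device from the platform flags.
import torch

IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_XPU_SYSTEM = hasattr(torch, "xpu") and torch.xpu.is_available()

def get_device() -> torch.device:
    # Hypothetical helper: CUDA and ROCm both use the "cuda" device type,
    # then Intel XPU, otherwise fall back to CPU.
    if (IS_CUDA_SYSTEM or IS_ROCM_SYSTEM) and torch.cuda.is_available():
        return torch.device("cuda")
    if IS_XPU_SYSTEM:
        return torch.device("xpu")
    return torch.device("cpu")

print(get_device())
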
@@ -8,6 +8,8 @@ from typing import List, Tuple, Optional
from loguru import logger
from functools import lru_cache

+from text_generation_server.utils.speculate import get_speculate
+
HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
@@ -18,7 +20,14 @@ except ImportError:
from accelerate import init_empty_weights

from text_generation_server.utils.gptq.quant_linear import QuantLinear
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
+
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex

HAS_AWQ = True
try:
@@ -437,7 +446,7 @@ class MedusaModel(torch.nn.Module):
        self.heads = torch.nn.ModuleList(
            [
                MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
-                for i in range(medusa_config["medusa_num_heads"])
+                for i in range(get_speculate())
            ]
        )
@@ -534,7 +543,7 @@ class MedusaHeadV2(nn.Module):
            )
            routing[k] = filename

-        self.n_medusa_heads = medusa_config["medusa_num_heads"]
+        self.n_medusa_heads = get_speculate()

        assert medusa_config["medusa_num_layers"] == 1
        self.linear = TensorParallelColumnLinear.load_multi(
@@ -696,6 +705,19 @@ class TensorParallelHead(SuperLayer):


class TensorParallelColumnLinear(SuperLayer):
+    @classmethod
+    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
+        """Specific method when the QKV was joined after the fact"""
+        weight = weights.get_weights_col_packed_gate_up(
+            prefix, quantize=config.quantize
+        )
+        if bias:
+            raise NotImplementedError("packed_gate_up only implemented without bias")
+        else:
+            bias = None
+        linear = get_linear(weight, bias, config.quantize)
+        return cls(linear)
+
    @classmethod
    def load_qkv(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the QKV was joined after the fact"""
@@ -799,7 +821,15 @@ try:

    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
-            if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+            if IS_XPU_SYSTEM:
+                res_out = hidden_states
+                out = ipex.llm.functional.add_layer_norm(
+                    residual, hidden_states, self.weight, self.bias, self.eps, True
+                )
+                if residual is not None:
+                    res_out = residual
+                return out, res_out
+            elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states
@@ -845,7 +875,20 @@ try:
            return cls(weight, eps)

        def forward(self, hidden_states, residual=None):
-            if hidden_states.shape[-1] > 8192:
+            if IS_XPU_SYSTEM:
+                residual_out = hidden_states
+                out = ipex.llm.functional.add_rms_norm(
+                    residual,
+                    hidden_states,
+                    self.weight,
+                    None,
+                    self.variance_epsilon,
+                    True,
+                )
+                if residual is not None:
+                    residual_out = residual
+                return out, residual_out
+            elif hidden_states.shape[-1] > 8192:
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states
@@ -971,6 +1014,10 @@ try:
                # Inplace operation, updating query and key.
                pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
+            elif IS_XPU_SYSTEM:
+                ipex.llm.functional.rotary_embedding(
+                    query, key, sin, cos, query.size(-1), True
+                )
            else:
                raise ValueError(
                    "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
@@ -1090,6 +1137,7 @@ try:
            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
            return cos.unsqueeze(1), sin.unsqueeze(1)
......
@@ -143,13 +143,16 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
        score = torch.gather(scores, 1, input_ids)
        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
        score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
+        # set score to 0 where input_ids is a padding token
+        score *= input_ids.ne(0)

        return scores.scatter_add_(1, input_ids, score)


class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
    r"""
-    Frequency penalty as defined by OpenAI
+    Frequency penalty as defined by OpenAI in
+    https://platform.openai.com/docs/guides/text-generation/parameter-details

    Args:
        frequency_penalty (`List[float]`):
@@ -163,13 +166,19 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        score = torch.gather(scores, 1, input_ids)
-        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
-        score = -torch.where(
-            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
-        )
-
-        return scores.scatter_add_(1, input_ids, score)
+        batch_size, input_size = input_ids.size()
+        vocab_size = scores.size(1)
+
+        # Calculate the frequency for each token so far
+        token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
+        token_freq.scatter_add_(
+            1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
+        )
+        token_freq /= input_size
+
+        # Apply the frequency penalty to logits
+        scores -= token_freq * self.penalty_tensor
+        return scores

    def filter(self, indices):
        self.penalty = [self.penalty[i] for i in indices]
......
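
The heterogeneous frequency penalty now subtracts penalty * count/length per token, matching the OpenAI definition, instead of rescaling gathered scores. A toy numeric check of the new computation (penalty values invented):

# Example (not part of the diff): the new frequency-penalty math on a toy batch.
import torch

input_ids = torch.tensor([[1, 1, 2, 3], [4, 4, 4, 5]])  # tokens generated so far
scores = torch.zeros(2, 6)                               # pretend logits, vocab of 6
penalty = torch.tensor([[0.5], [1.0]])                   # per-request frequency_penalty

batch_size, input_size = input_ids.size()
vocab_size = scores.size(1)

# Count how often each vocab id occurred, normalized by the sequence length.
token_freq = torch.zeros(batch_size, vocab_size)
token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
token_freq /= input_size

scores -= token_freq * penalty
print(scores)
# Row 0: token 1 appeared twice -> penalized by 0.5 * 2/4 = 0.25; tokens 2 and 3 by 0.125.
# Row 1: token 4 appeared three times -> penalized by 1.0 * 3/4 = 0.75; token 5 by 0.25.
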
import torch

-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
-from loguru import logger
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)

_PARTITION_SIZE = 512

+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+

def reshape_and_cache(
    key: torch.Tensor,
@@ -22,8 +27,11 @@ def reshape_and_cache(
    elif IS_ROCM_SYSTEM:
        from vllm import cache_ops

-        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots.int())
-        # cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+    elif IS_XPU_SYSTEM:
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
    else:
        raise ValueError("vllm is not supported on your system")
@@ -60,6 +68,22 @@ def attention(
    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
+
+    if IS_XPU_SYSTEM:
+        query = query.contiguous()
+        return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+            out,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+        )
+
    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
......
import re
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Set, Union
import math

import torch
@@ -143,12 +143,22 @@ class StopSequenceCriteria:
class StoppingCriteria:
    def __init__(
        self,
-        eos_token_id: int,
+        eos_token_ids: Optional[Union[Set[int], int]],
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens: int = 20,
        ignore_eos_token: bool = False,
    ):
-        self.eos_token_id = eos_token_id
+        if eos_token_ids is None:
+            eos_token_ids = set()
+        elif isinstance(eos_token_ids, int):
+            eos_token_ids = set([eos_token_ids])
+        elif isinstance(eos_token_ids, set):
+            eos_token_ids = eos_token_ids
+        else:
+            raise RuntimeError(
+                f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
+            )
+        self.eos_token_ids = eos_token_ids
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
@@ -160,7 +170,10 @@ class StoppingCriteria:
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

-        if not self.ignore_eos_token and last_token == self.eos_token_id:
+        if isinstance(last_token, torch.Tensor):
+            last_token = last_token.item()
+
+        if not self.ignore_eos_token and last_token in self.eos_token_ids:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        if self.stop_sequence_criterias:
@@ -184,8 +197,10 @@ class StoppingCriteria:
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
+        # TODO Hack because eos_token_id cannot be what we want.
+        eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
        return StoppingCriteria(
-            tokenizer.eos_token_id,
+            eos_token_id,
            stop_sequence_criterias,
            pb.max_new_tokens,
            pb.ignore_eos_token,
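
With eos_token_ids normalized to a set, models that define several end tokens stop on any of them, and from_pb now prefers a _eos_token_ids attribute when the tokenizer exposes one. A stripped-down sketch of the resulting behaviour (finish reasons are plain strings standing in for the protobuf enum; the example token ids are Llama-3-style values):

# Example (not part of the diff): stopping on any of several EOS ids.
from typing import Optional, Set, Union

class StoppingCriteria:
    """Stripped-down version of the class above, enough to show the EOS-set logic."""

    def __init__(self, eos_token_ids: Optional[Union[Set[int], int]], max_new_tokens: int = 20):
        if eos_token_ids is None:
            eos_token_ids = set()
        elif isinstance(eos_token_ids, int):
            eos_token_ids = {eos_token_ids}
        self.eos_token_ids = eos_token_ids
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0

    def __call__(self, last_token: int):
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, "length"
        if last_token in self.eos_token_ids:
            return True, "eos_token"
        return False, None

# e.g. Llama-3-style models stop on either <|end_of_text|> or <|eot_id|>.
criteria = StoppingCriteria(eos_token_ids={128001, 128009})
print(criteria(42))      # (False, None)
print(criteria(128009))  # (True, 'eos_token')
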
@@ -273,7 +288,7 @@ class HeterogeneousNextTokenChooser:
            else None
        )

-        if any([x != 1.0 for x in temperature]):
+        if any(x != 1.0 for x in temperature):
            do_sample = [
                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
            ]
@@ -281,15 +296,15 @@ class HeterogeneousNextTokenChooser:
            HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
        )

-        if any([x != 0 for x in top_k]):
+        if any(x != 0 for x in top_k):
            do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
            warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))

-        if any([x < 1.0 for x in top_p]):
+        if any(x < 1.0 for x in top_p):
            do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
            warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))

-        if any([x < 1.0 for x in typical_p]):
+        if any(x < 1.0 for x in typical_p):
            do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
            warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
......
@@ -141,6 +141,12 @@ class Weights:
        return weight

    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
+        return self.get_weights_col_packed(prefix, quantize, 3)
+
+    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
+        return self.get_weights_col_packed(prefix, quantize, 2)
+
+    def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
        """
        Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
        already alternating Q,K,V within the main tensor
@@ -181,8 +187,8 @@ class Weights:
        else:
            slice_ = self._get_slice(f"{prefix}.weight")
            total_size = slice_.get_shape()[0]
-            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
-            single_size = total_size // 3
+            assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
+            single_size = total_size // blocks
            world_size = self.process_group.size()
            rank = self.process_group.rank()
@@ -192,10 +198,11 @@ class Weights:
            block_size = single_size // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
-            q = slice_[start:stop]
-            k = slice_[start + single_size : stop + single_size]
-            v = slice_[start + 2 * single_size : stop + 2 * single_size]
-            weight = torch.cat([q, k, v], dim=0)
+            tensors = []
+            for i in range(blocks):
+                tensor = slice_[start + i * single_size : stop + i * single_size]
+                tensors.append(tensor)
+            weight = torch.cat(tensors, dim=0)
            weight = weight.to(device=self.device)
            weight = weight.to(dtype=self.dtype)
            return weight
......
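
Generalizing the packed-column loader to an arbitrary number of blocks lets the same slicing shard a fused gate/up projection (2 blocks) as well as fused QKV (3 blocks) across tensor-parallel ranks; load_gate_up in layers.py builds on it. A small sketch of the slicing on plain tensors (sizes and rank count are illustrative):

# Example (not part of the diff): per-rank slicing of a block-packed weight.
import torch

def shard_packed(weight: torch.Tensor, blocks: int, rank: int, world_size: int) -> torch.Tensor:
    """Take this rank's slice from each of the `blocks` stacked sub-matrices."""
    total_size = weight.shape[0]
    assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
    single_size = total_size // blocks
    assert single_size % world_size == 0
    block_size = single_size // world_size
    start, stop = rank * block_size, (rank + 1) * block_size
    tensors = [weight[start + i * single_size : stop + i * single_size] for i in range(blocks)]
    return torch.cat(tensors, dim=0)

# Fused gate_up_proj: 2 blocks of 8 rows each, sharded over 2 ranks.
packed = torch.arange(16).unsqueeze(1).float()          # shape (16, 1)
print(shard_packed(packed, blocks=2, rank=0, world_size=2).squeeze().tolist())
# [0.0, 1.0, 2.0, 3.0, 8.0, 9.0, 10.0, 11.0] -> rank 0 gets the first half of gate and of up
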