Commit 5a1cf2f0 authored by huangwb

Merge tag 'v2.0.2' into dev-rocm

parents 24f58bb6 6073ece4
......@@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size
def image_text_replacement(image_input, config, image_id) -> str:
if config.model_type == "idefics2":
# TODO technically depends on image splitting which is not implemented.
num_features = 320
return (
"<fake_token_around_image>"
+ "<image>" * num_features
+ "<fake_token_around_image>"
)
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config)
from loguru import logger
logger.info(f"Found {num_features} in image of resolution {height}x{width}")
return "<image>" * num_features
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def get_unpadded_features(
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
aspect_ratio: float = width / height
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
current_height = new_height
else:
new_width = (width * current_height) // height
current_width = new_width
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
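# Illustrative walk-through, assuming the default llava_next CLIP tower
# (image_size=336, patch_size=14, hence npatches=24) and a 640x640 input that is
# assigned a 2x2 grid of 336px tiles:
#   current_height = current_width = 24 * 2 = 48
#   aspect_ratio == current_aspect_ratio == 1.0, so the else branch leaves current_width at 48
#   unpadded_features = 48 * 48 = 2304, newline_features = 48
# get_number_of_features below adds base_features = 24**2 = 576, giving
# 2304 + 48 + 576 = 2928, which matches the commented assert further down.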
def get_number_of_features(height: int, width: int, config) -> int:
# From config
# Hardcoded for CLIP for now
......@@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
image_grid_pinpoints,
image_size,
)
height_of_patch = math.ceil(height / width * npatches)
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
# They are only added after width
newline_features = height_of_patch * num_patch_width
unpadded_features, newline_features = get_unpadded_features(
height, width, npatches, num_patch_height, num_patch_width
)
# The base patch covers the entire image
base_features = npatches**2
return unpadded_features + newline_features + base_features
......@@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
return image
# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
# assert get_number_of_features(640, 640) == 2928
class VlmCausalLMBatch(FlashMistralBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
@classmethod
......@@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
def concatenate(cls, batches):
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
......@@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
......@@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
for r in requests:
chunks = split(r.inputs)
full_text = ""
image_id = 0
for chunk in chunks:
if chunk["type"] == "text":
full_text += chunk["content"]
......@@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
"Cannot process input image not starting with data:"
)
image_input = processor.image_processor(image, return_tensors="pt")
height, width = image_input["image_sizes"][0]
num_features = get_number_of_features(height, width, config)
full_text += "<image>" * num_features
full_text += image_text_replacement(image_input, config, image_id)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
......@@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
if image_inputs:
image_inputs = {
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
"image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
}
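# Note: processors do not all return the same keys: the idefics2 image processor
# yields pixel_attention_mask while llava_next yields image_sizes, so each key is
# copied over only when present.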
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
......@@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
......@@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
def batch_type(self) -> Type[VlmCausalLMBatch]:
return VlmCausalLMBatch
def get_layer_config(self, model) -> Tuple[int, int, int]:
return (
len(model.language_model.model.layers),
model.language_model.model.num_key_value_heads,
model.language_model.model.head_size,
)
def max_past(self) -> Optional[int]:
return getattr(self.model.language_model, "max_past", None)
def forward(
self, batch: VlmCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
......@@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None)
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
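# Example: with graphs captured for the default batch sizes {1, 2, 4, 8, 16, 32}
# and bs == 3, sorted_padded_bs[0] == 4 and that graph is reused; if bs exceeds
# every captured size, cuda_graph stays None and the eager path below is taken.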
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
......@@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits
......
......@@ -2,6 +2,7 @@ import asyncio
import os
import torch
import time
import signal
from grpc import aio
from loguru import logger
......@@ -19,6 +20,21 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
class SignalHandler:
KEEP_PROCESSING = True
def __init__(self):
signal.signal(signal.SIGINT, self.exit_gracefully)
signal.signal(signal.SIGTERM, self.exit_gracefully)
def exit_gracefully(self, signum, frame):
print(f"Exiting gracefully: Signal {signum}")
self.KEEP_PROCESSING = False
signal_handler = SignalHandler()
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(
self,
......@@ -231,11 +247,8 @@ def serve(
logger.info("Server started at {}".format(local_url))
try:
await server.wait_for_termination()
except KeyboardInterrupt:
logger.info("Signal received. Shutting down")
await server.stop(0)
while signal_handler.KEEP_PROCESSING:
await asyncio.sleep(0.5)
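# Polling the flag set by SignalHandler (rather than catching KeyboardInterrupt)
# lets both SIGINT and SIGTERM, e.g. the stop signal sent by a container runtime,
# take the same graceful shutdown path.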
asyncio.run(
serve_inner(
......
......@@ -57,7 +57,14 @@ def initialize_torch_distributed():
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
try:
import oneccl_bindings_for_pytorch
backend = "ccl"
if os.getenv("CCL_WORKER_COUNT", None) is None:
os.environ["CCL_WORKER_COUNT"] = str(1)
except ImportError:
backend = "gloo"
options = None
if WORLD_SIZE == 1:
......
......@@ -2,69 +2,81 @@ import os
import torch
from loguru import logger
import math
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from text_generation_server.utils.import_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
IS_XPU_SYSTEM,
)
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
HAS_FLASH_ATTN = True
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
if not torch.cuda.is_available():
raise ImportError("CUDA is not available")
if IS_XPU_SYSTEM:
import intel_extension_for_pytorch as ipex
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
if not torch.cuda.is_available():
raise ImportError("CUDA is not available")
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
try:
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
try:
import flash_attn_2_cuda
except ImportError:
architecture_suffix = ""
if IS_CUDA_SYSTEM:
architecture_suffix = "-cuda"
try:
import flash_attn_2_cuda
except ImportError:
architecture_suffix = ""
if IS_CUDA_SYSTEM:
architecture_suffix = "-cuda"
elif IS_ROCM_SYSTEM:
architecture_suffix = "-rocm"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
if not (is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e:
try:
import flash_attn_cuda
except ImportError:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
elif IS_ROCM_SYSTEM:
architecture_suffix = "-rocm"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
if not (is_sm8x or is_sm90) and IS_CUDA_SYSTEM:
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e:
try:
import flash_attn_cuda
except ImportError:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
elif IS_ROCM_SYSTEM:
for idx in range(torch.cuda.device_count()):
if "MI210" not in torch.cuda.get_device_name(
idx
) and "MI250" not in torch.cuda.get_device_name(idx):
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)
for idx in range(torch.cuda.device_count()):
if "MI210" not in torch.cuda.get_device_name(
idx
) and "MI250" not in torch.cuda.get_device_name(idx):
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)
logger.warning(f"Unable to use Flash Attention V2: {e}")
HAS_FLASH_ATTN = True
logger.warning(f"Unable to use Flash Attention V2: {e}")
HAS_FLASH_ATTN = True
def attention(
......@@ -80,6 +92,28 @@ def attention(
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if IS_XPU_SYSTEM:
if window_size_left != -1:
raise ValueError(
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
return ipex.llm.functional.varlen_attention(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None,
)
if HAS_FLASH_ATTN_V2_CUDA:
return flash_attn_2_cuda.varlen_fwd(
q,
......
import torch
def is_xpu_available():
try:
import intel_extension_for_pytorch
except ImportError:
return False
return hasattr(torch, "xpu") and torch.xpu.is_available()
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_XPU_SYSTEM = is_xpu_available()
......@@ -8,6 +8,8 @@ from typing import List, Tuple, Optional
from loguru import logger
from functools import lru_cache
from text_generation_server.utils.speculate import get_speculate
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
......@@ -18,7 +20,14 @@ except ImportError:
from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from text_generation_server.utils.import_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
IS_XPU_SYSTEM,
)
if IS_XPU_SYSTEM:
import intel_extension_for_pytorch as ipex
HAS_AWQ = True
try:
......@@ -437,7 +446,7 @@ class MedusaModel(torch.nn.Module):
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
for i in range(medusa_config["medusa_num_heads"])
for i in range(get_speculate())
]
)
......@@ -534,7 +543,7 @@ class MedusaHeadV2(nn.Module):
)
routing[k] = filename
self.n_medusa_heads = medusa_config["medusa_num_heads"]
self.n_medusa_heads = get_speculate()
assert medusa_config["medusa_num_layers"] == 1
self.linear = TensorParallelColumnLinear.load_multi(
......@@ -696,6 +705,19 @@ class TensorParallelHead(SuperLayer):
class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_gate_up(
prefix, quantize=config.quantize
)
if bias:
raise NotImplementedError("packed_gate_up only implemented without bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
return cls(linear)
@classmethod
def load_qkv(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact"""
......@@ -799,7 +821,15 @@ try:
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
if IS_XPU_SYSTEM:
res_out = hidden_states
out = ipex.llm.functional.add_layer_norm(
residual, hidden_states, self.weight, self.bias, self.eps, True
)
if residual is not None:
res_out = residual
return out, res_out
elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
if residual is not None:
hidden_states += residual
residual = hidden_states
......@@ -845,7 +875,20 @@ try:
return cls(weight, eps)
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if IS_XPU_SYSTEM:
residual_out = hidden_states
out = ipex.llm.functional.add_rms_norm(
residual,
hidden_states,
self.weight,
None,
self.variance_epsilon,
True,
)
if residual is not None:
residual_out = residual
return out, residual_out
elif hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
......@@ -971,6 +1014,10 @@ try:
# Inplace operation, updating query and key.
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
elif IS_XPU_SYSTEM:
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), True
)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
......@@ -1090,6 +1137,7 @@ try:
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
# Note: this unsqueeze is not necessary on ROCm + the VLLM ROPE implementation, but we leave it as is to avoid yet another control flow.
return cos.unsqueeze(1), sin.unsqueeze(1)
......
......@@ -143,13 +143,16 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
# set score to 0 where input_ids is a padding token
score *= input_ids.ne(0)
return scores.scatter_add_(1, input_ids, score)
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
r"""
Frequency penalty as defined by OpenAI
Frequency penalty as defined by OpenAI in
https://platform.openai.com/docs/guides/text-generation/parameter-details
Args:
frequency_penalty (`List[float]`):
......@@ -163,13 +166,19 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
batch_size, input_size = input_ids.size()
vocab_size = scores.size(1)
# Calculate the frequency for each token so far
token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
token_freq.scatter_add_(
1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
)
token_freq /= input_size
return scores.scatter_add_(1, input_ids, score)
# Apply the frequency penalty to logits
scores -= token_freq * self.penalty_tensor
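# Worked example: with input_ids = [[5, 5, 7]] and penalty 0.5, token_freq[0, 5]
# is 2/3 and token_freq[0, 7] is 1/3, so the logits of token 5 drop by 0.5 * 2/3,
# those of token 7 by 0.5 * 1/3, and unseen tokens are left untouched.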
return scores
def filter(self, indices):
self.penalty = [self.penalty[i] for i in indices]
......
import torch
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from loguru import logger
from text_generation_server.utils.import_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
IS_XPU_SYSTEM,
)
_PARTITION_SIZE = 512
if IS_XPU_SYSTEM:
import intel_extension_for_pytorch as ipex
def reshape_and_cache(
key: torch.Tensor,
......@@ -22,8 +27,11 @@ def reshape_and_cache(
elif IS_ROCM_SYSTEM:
from vllm import cache_ops
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots.int())
# cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
elif IS_XPU_SYSTEM:
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
else:
raise ValueError("vllm is not supported on your system")
......@@ -60,6 +68,22 @@ def attention(
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
if IS_XPU_SYSTEM:
query = query.contiguous()
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
......
import re
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Set, Union
import math
import torch
......@@ -143,12 +143,22 @@ class StopSequenceCriteria:
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
eos_token_ids: Optional[Union[Set[int], int]],
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens: int = 20,
ignore_eos_token: bool = False,
):
self.eos_token_id = eos_token_id
if eos_token_ids is None:
eos_token_ids = set()
elif isinstance(eos_token_ids, int):
eos_token_ids = set([eos_token_ids])
elif isinstance(eos_token_ids, set):
eos_token_ids = eos_token_ids
else:
raise RuntimeError(
f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
)
self.eos_token_ids = eos_token_ids
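# A set is used because some models define several stop ids (for instance an
# end-of-text token plus an end-of-turn token) and generation should stop on any
# of them.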
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
......@@ -160,7 +170,10 @@ class StoppingCriteria:
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if not self.ignore_eos_token and last_token == self.eos_token_id:
if isinstance(last_token, torch.Tensor):
last_token = last_token.item()
if not self.ignore_eos_token and last_token in self.eos_token_ids:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
if self.stop_sequence_criterias:
......@@ -184,8 +197,10 @@ class StoppingCriteria:
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
# TODO: hack because tokenizer.eos_token_id can only hold a single id, while some models need a set of stop ids.
eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
return StoppingCriteria(
tokenizer.eos_token_id,
eos_token_id,
stop_sequence_criterias,
pb.max_new_tokens,
pb.ignore_eos_token,
......@@ -273,7 +288,7 @@ class HeterogeneousNextTokenChooser:
else None
)
if any([x != 1.0 for x in temperature]):
if any(x != 1.0 for x in temperature):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
]
......@@ -281,15 +296,15 @@ class HeterogeneousNextTokenChooser:
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
)
if any([x != 0 for x in top_k]):
if any(x != 0 for x in top_k):
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
if any([x < 1.0 for x in top_p]):
if any(x < 1.0 for x in top_p):
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
if any([x < 1.0 for x in typical_p]):
if any(x < 1.0 for x in typical_p):
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
......
......@@ -141,6 +141,12 @@ class Weights:
return weight
def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
return self.get_weights_col_packed(prefix, quantize, 3)
def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
return self.get_weights_col_packed(prefix, quantize, 2)
def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
"""
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
already alternating Q,K,V within the main tensor
......@@ -181,8 +187,8 @@ class Weights:
else:
slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0]
assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
single_size = total_size // 3
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // blocks
world_size = self.process_group.size()
rank = self.process_group.rank()
......@@ -192,10 +198,11 @@ class Weights:
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q = slice_[start:stop]
k = slice_[start + single_size : stop + single_size]
v = slice_[start + 2 * single_size : stop + 2 * single_size]
weight = torch.cat([q, k, v], dim=0)
tensors = []
for i in range(blocks):
tensor = slice_[start + i * single_size : stop + i * single_size]
tensors.append(tensor)
weight = torch.cat(tensors, dim=0)
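# Sharding example: with blocks=2 (packed gate_up), total_size=8192 and
# world_size=2, single_size=4096 and block_size=2048; rank 0 concatenates rows
# [0:2048] and [4096:6144], rank 1 rows [2048:4096] and [6144:8192], so every rank
# keeps its slice of both the gate and the up projection.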
weight = weight.to(device=self.device)
weight = weight.to(dtype=self.dtype)
return weight
......