Unverified Commit 2f8844ba authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

Re-enable the 80 char line width limit (#3305)

parent 4b59f00e
...@@ -45,7 +45,7 @@ class ModelConfig: ...@@ -45,7 +45,7 @@ class ModelConfig:
a tag name, or a commit id. If unspecified, will use the default a tag name, or a commit id. If unspecified, will use the default
version. version.
code_revision: The specific revision to use for the model code on code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version. commit id. If unspecified, will use the default version.
tokenizer_revision: The specific tokenizer version to use. It can be a tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use branch name, a tag name, or a commit id. If unspecified, will use
...@@ -189,8 +189,8 @@ class ModelConfig: ...@@ -189,8 +189,8 @@ class ModelConfig:
if is_hip( if is_hip(
) and self.quantization in rocm_not_supported_quantization: ) and self.quantization in rocm_not_supported_quantization:
raise ValueError( raise ValueError(
f"{self.quantization} quantization is currently not supported " f"{self.quantization} quantization is currently not "
f"in ROCm.") f"supported in ROCm.")
if self.quantization != "marlin": if self.quantization != "marlin":
logger.warning( logger.warning(
f"{self.quantization} quantization is not fully " f"{self.quantization} quantization is not fully "
...@@ -321,7 +321,8 @@ class CacheConfig: ...@@ -321,7 +321,8 @@ class CacheConfig:
self.num_cpu_blocks = None self.num_cpu_blocks = None
def metrics_info(self): def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus metrics info # convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
return {key: str(value) for key, value in self.__dict__.items()} return {key: str(value) for key, value in self.__dict__.items()}
def _verify_args(self) -> None: def _verify_args(self) -> None:
...@@ -399,8 +400,9 @@ class ParallelConfig: ...@@ -399,8 +400,9 @@ class ParallelConfig:
) -> None: ) -> None:
self.pipeline_parallel_size = pipeline_parallel_size self.pipeline_parallel_size = pipeline_parallel_size
if is_neuron(): if is_neuron():
# For Neuron device support, here we assign TP=1 to avoid sharding within vLLM directly. # For Neuron device support, here we assign TP=1 to avoid sharding
# Transformer-neuronx would take neuron_tp_degree attribute, and distribute the workload # within vLLM directly. Transformer-neuronx would take
# neuron_tp_degree attribute, and distribute the workload
# to multiple NeuronCores. # to multiple NeuronCores.
self.tensor_parallel_size = 1 self.tensor_parallel_size = 1
self.neuron_tp_degree = tensor_parallel_size self.neuron_tp_degree = tensor_parallel_size
......
...@@ -95,13 +95,15 @@ class BlockAllocator: ...@@ -95,13 +95,15 @@ class BlockAllocator:
del self.cached_blocks[block.block_hash] del self.cached_blocks[block.block_hash]
def get_num_free_blocks(self) -> int: def get_num_free_blocks(self) -> int:
return self.num_blocks - self.current_num_blocks + self.evictor.num_blocks return (self.num_blocks - self.current_num_blocks +
self.evictor.num_blocks)
def contains_block(self, block_hash: int) -> bool: def contains_block(self, block_hash: int) -> bool:
return block_hash in self.cached_blocks or block_hash in self.evictor return block_hash in self.cached_blocks or block_hash in self.evictor
def update_hash(self, block_hash: int, block: PhysicalTokenBlock): def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
# If caching is enabled, update the hash of block and the cached_blocks dictionary. # If caching is enabled, update the hash of block and the
# cached_blocks dictionary.
if self.enable_caching: if self.enable_caching:
assert not self.contains_block(block_hash) assert not self.contains_block(block_hash)
old_hash = block.block_hash old_hash = block.block_hash
...@@ -218,10 +220,12 @@ class BlockSpaceManager: ...@@ -218,10 +220,12 @@ class BlockSpaceManager:
seq: Sequence, seq: Sequence,
last_block: PhysicalTokenBlock, last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock: ) -> PhysicalTokenBlock:
# Compute a new hash for the block so that it can be shared by other Sequences # Compute a new hash for the block so that it can be shared by
# other Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
# if new_hash is already in the cached table, then free last_block and return the cached version # if new_hash is already in the cached table, then free last_block
# and return the cached version
if self.gpu_allocator.contains_block(new_hash): if self.gpu_allocator.contains_block(new_hash):
self.gpu_allocator.free(last_block) self.gpu_allocator.free(last_block)
return self.gpu_allocator.allocate(new_hash) return self.gpu_allocator.allocate(new_hash)
...@@ -289,7 +293,8 @@ class BlockSpaceManager: ...@@ -289,7 +293,8 @@ class BlockSpaceManager:
assert last_block.device == Device.GPU assert last_block.device == Device.GPU
if last_block.ref_count == 1: if last_block.ref_count == 1:
# Not shared with other sequences. Appendable. # Not shared with other sequences. Appendable.
# If the last block is now complete, promote it to a full block so that it can be shared # If the last block is now complete, promote it to a full block so
# that it can be shared
new_block = self._maybe_promote_last_block(seq, last_block) new_block = self._maybe_promote_last_block(seq, last_block)
block_table[-1] = new_block block_table[-1] = new_block
return None return None
......
...@@ -39,9 +39,9 @@ class Evictor(ABC): ...@@ -39,9 +39,9 @@ class Evictor(ABC):
@abstractmethod @abstractmethod
def remove(self, block_hash: int) -> PhysicalTokenBlock: def remove(self, block_hash: int) -> PhysicalTokenBlock:
"""Simply removes the block with the hash value block_hash from the """Simply removes the block with the hash value block_hash from the
evictor. Caller is responsible for making sure that block_hash is contained evictor. Caller is responsible for making sure that block_hash is
in the evictor before calling remove. Should be used to "bring back" blocks contained in the evictor before calling remove. Should be used to
that have been freed but not evicted yet. "bring back" blocks that have been freed but not evicted yet.
""" """
pass pass
......
...@@ -214,8 +214,8 @@ class Scheduler: ...@@ -214,8 +214,8 @@ class Scheduler:
lora_int_id = 0 lora_int_id = 0
if self.lora_enabled: if self.lora_enabled:
lora_int_id = seq_group.lora_int_id lora_int_id = seq_group.lora_int_id
if lora_int_id > 0 and lora_int_id not in curr_loras and len( if (lora_int_id > 0 and lora_int_id not in curr_loras
curr_loras) >= self.lora_config.max_loras: and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so # We don't have a space for another LoRA, so
# we ignore this request for now. # we ignore this request for now.
leftover_waiting_sequences.appendleft(seq_group) leftover_waiting_sequences.appendleft(seq_group)
...@@ -309,8 +309,8 @@ class Scheduler: ...@@ -309,8 +309,8 @@ class Scheduler:
lora_int_id = 0 lora_int_id = 0
if self.lora_enabled: if self.lora_enabled:
lora_int_id = seq_group.lora_int_id lora_int_id = seq_group.lora_int_id
if lora_int_id > 0 and lora_int_id not in curr_loras and len( if (lora_int_id > 0 and lora_int_id not in curr_loras
curr_loras) >= self.lora_config.max_loras: and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so # We don't have a space for another LoRA, so
# we ignore this request for now. # we ignore this request for now.
leftover_swapped.appendleft(seq_group) leftover_swapped.appendleft(seq_group)
......
...@@ -100,7 +100,8 @@ class LLMEngine: ...@@ -100,7 +100,8 @@ class LLMEngine:
f"download_dir={model_config.download_dir!r}, " f"download_dir={model_config.download_dir!r}, "
f"load_format={model_config.load_format}, " f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, " f"disable_custom_all_reduce="
f"{parallel_config.disable_custom_all_reduce}, "
f"quantization={model_config.quantization}, " f"quantization={model_config.quantization}, "
f"enforce_eager={model_config.enforce_eager}, " f"enforce_eager={model_config.enforce_eager}, "
f"kv_cache_dtype={cache_config.cache_dtype}, " f"kv_cache_dtype={cache_config.cache_dtype}, "
...@@ -929,7 +930,8 @@ class LLMEngine: ...@@ -929,7 +930,8 @@ class LLMEngine:
# Latency Timings. # Latency Timings.
time_last_iters = [] time_last_iters = []
for seq_group in scheduler_outputs.scheduled_seq_groups: for seq_group in scheduler_outputs.scheduled_seq_groups:
# Time since last token. (n.b. updates seq_group.metrics.last_token_time) # Time since last token.
# (n.b. updates seq_group.metrics.last_token_time)
time_last_iters.append(seq_group.get_last_latency(now)) time_last_iters.append(seq_group.get_last_latency(now))
# Time since arrival for all finished requests. # Time since arrival for all finished requests.
if seq_group.is_finished(): if seq_group.is_finished():
...@@ -961,16 +963,17 @@ class LLMEngine: ...@@ -961,16 +963,17 @@ class LLMEngine:
for token_id, sample_logprob in logprobs.items(): for token_id, sample_logprob in logprobs.items():
if (sample_logprob.decoded_token is None and token_id != -1): if (sample_logprob.decoded_token is None and token_id != -1):
all_input_ids_with_logprob = all_input_ids[:-1] + [token_id] all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
_, new_text, prefix_offset, read_offset = detokenize_incrementally( (_, new_text, prefix_offset,
self.get_tokenizer_for_seq(seq), read_offset) = detokenize_incrementally(
all_input_ids=all_input_ids_with_logprob, self.get_tokenizer_for_seq(seq),
prev_tokens=seq.tokens, all_input_ids=all_input_ids_with_logprob,
prefix_offset=seq.prefix_offset, prev_tokens=seq.tokens,
read_offset=seq.read_offset, prefix_offset=seq.prefix_offset,
skip_special_tokens=prms.skip_special_tokens, read_offset=seq.read_offset,
spaces_between_special_tokens=prms. skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens, spaces_between_special_tokens=prms.
) spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text sample_logprob.decoded_token = new_text
def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
......
from vllm.logger import init_logger from vllm.logger import init_logger
from prometheus_client import Counter, Gauge, Histogram, Info, REGISTRY, disable_created_metrics from prometheus_client import (Counter, Gauge, Histogram, Info, REGISTRY,
disable_created_metrics)
import time import time
import numpy as np import numpy as np
...@@ -177,10 +178,12 @@ class StatLogger: ...@@ -177,10 +178,12 @@ class StatLogger:
def _log_prometheus_interval(self, prompt_throughput: float, def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None: generation_throughput: float) -> None:
# Logs metrics to prometheus that are computed every logging_interval. # Logs metrics to prometheus that are computed every logging_interval.
# Support legacy gauge metrics that make throughput calculations on the vLLM side. # Support legacy gauge metrics that make throughput calculations on
# Moving forward, we should use counters like counter_prompt_tokens, counter_generation_tokens # the vLLM side. Moving forward, we should use counters like
# Which log raw data and calculate summaries using rate() on the grafana/prometheus side. # counter_prompt_tokens, counter_generation_tokens
# See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666 # Which log raw data and calculate summaries using rate() on the
# grafana/prometheus side. See
# https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
self.metrics.gauge_avg_prompt_throughput.labels( self.metrics.gauge_avg_prompt_throughput.labels(
**self.labels).set(prompt_throughput) **self.labels).set(prompt_throughput)
self.metrics.gauge_avg_generation_throughput.labels( self.metrics.gauge_avg_generation_throughput.labels(
...@@ -188,7 +191,7 @@ class StatLogger: ...@@ -188,7 +191,7 @@ class StatLogger:
def log(self, stats: Stats) -> None: def log(self, stats: Stats) -> None:
"""Called by LLMEngine. """Called by LLMEngine.
Logs to prometheus and tracked stats every iteration. Logs to prometheus and tracked stats every iteration.
Logs to Stdout every self.local_interval seconds.""" Logs to Stdout every self.local_interval seconds."""
# Log to prometheus. # Log to prometheus.
...@@ -200,8 +203,8 @@ class StatLogger: ...@@ -200,8 +203,8 @@ class StatLogger:
# Log locally every local_interval seconds. # Log locally every local_interval seconds.
if self._local_interval_elapsed(stats.now): if self._local_interval_elapsed(stats.now):
# Compute summary metrics for tracked stats (and log them
# Compute summary metrics for tracked stats (and log them to promethus if applicable). # to promethus if applicable).
prompt_throughput = self._get_throughput(self.num_prompt_tokens, prompt_throughput = self._get_throughput(self.num_prompt_tokens,
now=stats.now) now=stats.now)
generation_throughput = self._get_throughput( generation_throughput = self._get_throughput(
...@@ -213,7 +216,8 @@ class StatLogger: ...@@ -213,7 +216,8 @@ class StatLogger:
# Log to stdout. # Log to stdout.
logger.info( logger.info(
f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, " f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, "
f"Avg generation throughput: {generation_throughput:.1f} tokens/s, " f"Avg generation throughput: "
f"{generation_throughput:.1f} tokens/s, "
f"Running: {stats.num_running} reqs, " f"Running: {stats.num_running} reqs, "
f"Swapped: {stats.num_swapped} reqs, " f"Swapped: {stats.num_swapped} reqs, "
f"Pending: {stats.num_waiting} reqs, " f"Pending: {stats.num_waiting} reqs, "
......
""" """
NOTE: This API server is used only for demonstrating usage of AsyncEngine and simple performance benchmarks. NOTE: This API server is used only for demonstrating usage of AsyncEngine
It is not intended for production use. For production use, we recommend using our OpenAI compatible server. and simple performance benchmarks. It is not intended for production use.
We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead. For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
""" """
import argparse import argparse
......
...@@ -18,7 +18,9 @@ from fastapi.responses import JSONResponse, StreamingResponse, Response ...@@ -18,7 +18,9 @@ from fastapi.responses import JSONResponse, StreamingResponse, Response
import vllm import vllm
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse from vllm.entrypoints.openai.protocol import (CompletionRequest,
ChatCompletionRequest,
ErrorResponse)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
...@@ -84,13 +86,11 @@ def parse_args(): ...@@ -84,13 +86,11 @@ def parse_args():
type=json.loads, type=json.loads,
default=["*"], default=["*"],
help="allowed headers") help="allowed headers")
parser.add_argument( parser.add_argument("--api-key",
"--api-key", type=str,
type=str, default=None,
default=None, help="If provided, the server will require this key "
help= "to be presented in the header.")
"If provided, the server will require this key to be presented in the header."
)
parser.add_argument("--served-model-name", parser.add_argument("--served-model-name",
type=str, type=str,
default=None, default=None,
...@@ -103,9 +103,8 @@ def parse_args(): ...@@ -103,9 +103,8 @@ def parse_args():
default=None, default=None,
nargs='+', nargs='+',
action=LoRAParserAction, action=LoRAParserAction,
help= help="LoRA module configurations in the format name=path. "
"LoRA module configurations in the format name=path. Multiple modules can be specified." "Multiple modules can be specified.")
)
parser.add_argument("--chat-template", parser.add_argument("--chat-template",
type=str, type=str,
default=None, default=None,
...@@ -138,9 +137,10 @@ def parse_args(): ...@@ -138,9 +137,10 @@ def parse_args():
help="Additional ASGI middleware to apply to the app. " help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. " "We accept multiple --middleware arguments. "
"The value should be an import path. " "The value should be an import path. "
"If a function is provided, vLLM will add it to the server using @app.middleware('http'). " "If a function is provided, vLLM will add it to the server "
"If a class is provided, vLLM will add it to the server using app.add_middleware(). " "using @app.middleware('http'). "
) "If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
return parser.parse_args() return parser.parse_args()
...@@ -235,9 +235,8 @@ if __name__ == "__main__": ...@@ -235,9 +235,8 @@ if __name__ == "__main__":
elif inspect.iscoroutinefunction(imported): elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported) app.middleware("http")(imported)
else: else:
raise ValueError( raise ValueError(f"Invalid middleware {middleware}. "
f"Invalid middleware {middleware}. Must be a function or a class." f"Must be a function or a class.")
)
logger.info(f"vLLM API server version {vllm.__version__}") logger.info(f"vLLM API server version {vllm.__version__}")
logger.info(f"args: {args}") logger.info(f"args: {args}")
......
...@@ -12,7 +12,8 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -12,7 +12,8 @@ from vllm.entrypoints.openai.protocol import (
UsageInfo) UsageInfo)
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -37,8 +38,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -37,8 +38,9 @@ class OpenAIServingChat(OpenAIServing):
ChatCompletionResponse]: ChatCompletionResponse]:
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API. for the API specification. This API mimics the OpenAI
ChatCompletion API.
NOTE: Currently we do not support the following feature: NOTE: Currently we do not support the following feature:
- function_call (Users should implement this by themselves) - function_call (Users should implement this by themselves)
...@@ -116,7 +118,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -116,7 +118,8 @@ class OpenAIServingChat(OpenAIServing):
# the result_generator, it needs to be sent as the FIRST # the result_generator, it needs to be sent as the FIRST
# response (by the try...catch). # response (by the try...catch).
if first_iteration: if first_iteration:
# Send first response for each request.n (index) with the role # Send first response for each request.n (index) with
# the role
role = self.get_chat_request_role(request) role = self.get_chat_request_role(request)
for i in range(request.n): for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
...@@ -133,7 +136,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -133,7 +136,8 @@ class OpenAIServingChat(OpenAIServing):
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
# Send response to echo the input portion of the last message # Send response to echo the input portion of the
# last message
if request.echo: if request.echo:
last_msg_content = "" last_msg_content = ""
if request.messages and isinstance( if request.messages and isinstance(
...@@ -145,11 +149,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -145,11 +149,12 @@ class OpenAIServingChat(OpenAIServing):
if last_msg_content: if last_msg_content:
for i in range(request.n): for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice( choice_data = (
index=i, ChatCompletionResponseStreamChoice(
delta=DeltaMessage( index=i,
content=last_msg_content), delta=DeltaMessage(
finish_reason=None) content=last_msg_content),
finish_reason=None))
chunk = ChatCompletionStreamResponse( chunk = ChatCompletionStreamResponse(
id=request_id, id=request_id,
object=chunk_object_type, object=chunk_object_type,
......
import asyncio import asyncio
import time import time
from fastapi import Request from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict, Tuple from typing import (AsyncGenerator, AsyncIterator, Callable, List, Optional,
Dict, Tuple)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
...@@ -16,7 +17,8 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -16,7 +17,8 @@ from vllm.entrypoints.openai.protocol import (
) )
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -44,9 +46,8 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: ...@@ -44,9 +46,8 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
prompt_is_tokens = True prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays prompts = prompt # case 4: array of token arrays
else: else:
raise ValueError( raise ValueError("prompt must be a string, array of strings, "
"prompt must be a string, array of strings, array of tokens, or array of token arrays" "array of tokens, or array of token arrays")
)
return prompt_is_tokens, prompts return prompt_is_tokens, prompts
...@@ -156,7 +157,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -156,7 +157,8 @@ class OpenAIServingCompletion(OpenAIServing):
int, RequestOutput]] = merge_async_iterators(*generators) int, RequestOutput]] = merge_async_iterators(*generators)
# Similar to the OpenAI API, when n != best_of, we do not stream the # Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search. # results. In addition, we do not stream the results when use
# beam search.
stream = (request.stream stream = (request.stream
and (request.best_of is None or request.n == request.best_of) and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search) and not request.use_beam_search)
...@@ -223,7 +225,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -223,7 +225,8 @@ class OpenAIServingCompletion(OpenAIServing):
for output in res.outputs: for output in res.outputs:
i = output.index + prompt_idx * request.n i = output.index + prompt_idx * request.n
# TODO(simon): optimize the performance by avoiding full text O(n^2) sending. # TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
# only return the prompt # only return the prompt
...@@ -231,11 +234,12 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -231,11 +234,12 @@ class OpenAIServingCompletion(OpenAIServing):
delta_token_ids = res.prompt_token_ids delta_token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs top_logprobs = res.prompt_logprobs
has_echoed[i] = True has_echoed[i] = True
elif request.echo and request.max_tokens > 0 and not has_echoed[ elif (request.echo and request.max_tokens > 0
i]: and not has_echoed[i]):
# echo the prompt and first token # echo the prompt and first token
delta_text = res.prompt + output.text delta_text = res.prompt + output.text
delta_token_ids = res.prompt_token_ids + output.token_ids delta_token_ids = (res.prompt_token_ids +
output.token_ids)
top_logprobs = res.prompt_logprobs + (output.logprobs top_logprobs = res.prompt_logprobs + (output.logprobs
or []) or [])
has_echoed[i] = True has_echoed[i] = True
...@@ -248,7 +252,9 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -248,7 +252,9 @@ class OpenAIServingCompletion(OpenAIServing):
i]:] if output.logprobs else None i]:] if output.logprobs else None
if request.logprobs is not None: if request.logprobs is not None:
assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested" assert top_logprobs is not None, (
"top_logprobs must be provided when logprobs "
"is requested")
logprobs = self._create_logprobs( logprobs = self._create_logprobs(
token_ids=delta_token_ids, token_ids=delta_token_ids,
top_logprobs=top_logprobs, top_logprobs=top_logprobs,
......
...@@ -50,10 +50,12 @@ class OpenAIServing: ...@@ -50,10 +50,12 @@ class OpenAIServing:
except RuntimeError: except RuntimeError:
event_loop = None event_loop = None
if event_loop is not None and event_loop.is_running( if event_loop is not None and event_loop.is_running():
): # If the current is instanced by Ray Serve, there is already a running event loop # If the current is instanced by Ray Serve,
# there is already a running event loop
event_loop.create_task(self._post_init()) event_loop.create_task(self._post_init())
else: # When using single vLLM without engine_use_ray else:
# When using single vLLM without engine_use_ray
asyncio.run(self._post_init()) asyncio.run(self._post_init())
async def _post_init(self): async def _post_init(self):
...@@ -178,8 +180,9 @@ class OpenAIServing: ...@@ -178,8 +180,9 @@ class OpenAIServing:
if token_num + request.max_tokens > self.max_model_len: if token_num + request.max_tokens > self.max_model_len:
raise ValueError( raise ValueError(
f"This model's maximum context length is {self.max_model_len} tokens. " f"This model's maximum context length is "
f"However, you requested {request.max_tokens + token_num} tokens " f"{self.max_model_len} tokens. However, you requested "
f"{request.max_tokens + token_num} tokens "
f"({token_num} in the messages, " f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). " f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", ) f"Please reduce the length of the messages or completion.", )
......
...@@ -20,10 +20,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -20,10 +20,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear, RowParallelLinear,
QKVParallelLinear, QKVParallelLinear,
MergedColumnParallelLinear) MergedColumnParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim from vllm.model_executor.parallel_utils.utils import (
split_tensor_along_last_dim)
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass
...@@ -84,7 +86,8 @@ def _apply_lora_packed_nslice( ...@@ -84,7 +86,8 @@ def _apply_lora_packed_nslice(
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
indices: (batch_size) indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size) output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...), where n is number of slices output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
""" """
org_output = output org_output = output
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
...@@ -819,9 +822,8 @@ class SamplerWithLoRA(BaseLayerWithLoRA): ...@@ -819,9 +822,8 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
) -> None: ) -> None:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
if 32000 < self.base_layer.vocab_size > 33024: if 32000 < self.base_layer.vocab_size > 33024:
raise ValueError( raise ValueError("When using LoRA, vocab size must be "
"When using LoRA, vocab size must be 32000 >= vocab_size <= 33024" "32000 >= vocab_size <= 33024")
)
self.lora_a_stacked = torch.zeros( self.lora_a_stacked = torch.zeros(
( (
max_loras, max_loras,
......
...@@ -13,7 +13,8 @@ from torch import nn ...@@ -13,7 +13,8 @@ from torch import nn
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.utils import LRUCache, in_wsl from vllm.utils import LRUCache, in_wsl
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
from_layer_sampler)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
......
...@@ -154,10 +154,9 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -154,10 +154,9 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
f"LoRA rank {lora.rank} is greater than max_lora_rank " f"LoRA rank {lora.rank} is greater than max_lora_rank "
f"{self.lora_config.max_lora_rank}.") f"{self.lora_config.max_lora_rank}.")
if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
raise ValueError( raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
f"LoRA added vocab size {lora.extra_vocab_size} is greater than " f"is greater than lora_extra_vocab_size "
f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}." f"{self.lora_config.lora_extra_vocab_size}.")
)
return lora return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
......
...@@ -8,8 +8,10 @@ from re import escape as regex_escape ...@@ -8,8 +8,10 @@ from re import escape as regex_escape
from typing import Union, Tuple from typing import Union, Tuple
from pydantic import BaseModel from pydantic import BaseModel
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest from vllm.entrypoints.openai.protocol import (CompletionRequest,
from vllm.model_executor.guided_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor ChatCompletionRequest)
from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
RegexLogitsProcessor)
class GuidedDecodingMode(Enum): class GuidedDecodingMode(Enum):
......
...@@ -107,12 +107,15 @@ class JSONLogitsProcessor(RegexLogitsProcessor): ...@@ -107,12 +107,15 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
Parameters Parameters
---------- ----------
schema schema
A JSON schema that encodes the structure we want the model to generate A JSON schema that encodes the structure we want the model to
generate
tokenizer tokenizer
The model's tokenizer The model's tokenizer
whitespace_pattern whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Pattern to use for JSON syntactic whitespace (doesn't impact
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` string literals)
Example: allow only a single space or newline with
`whitespace_pattern=r"[\n ]?"`
""" """
if isinstance(schema, type(BaseModel)): if isinstance(schema, type(BaseModel)):
schema_str = json.dumps(schema.model_json_schema()) schema_str = json.dumps(schema.model_json_schema())
...@@ -122,8 +125,8 @@ class JSONLogitsProcessor(RegexLogitsProcessor): ...@@ -122,8 +125,8 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
schema_str = schema schema_str = schema
else: else:
raise ValueError( raise ValueError(
f"Cannot parse schema {schema}. The schema must be either " + f"Cannot parse schema {schema}. The schema must be either "
"a Pydantic object, a dictionary or a string that contains the JSON " f"a Pydantic object, a dictionary or a string that contains "
+ "Schema specification") f"the JSON Schema specification")
regex_string = build_regex_from_schema(schema_str, whitespace_pattern) regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer) super().__init__(regex_string, tokenizer)
...@@ -35,12 +35,12 @@ class Attention(nn.Module): ...@@ -35,12 +35,12 @@ class Attention(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
if _use_flash_attn(): if _use_flash_attn():
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501
self.backend = FlashAttentionBackend(num_heads, head_size, scale, self.backend = FlashAttentionBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes, num_kv_heads, alibi_slopes,
sliding_window) sliding_window)
else: else:
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend # noqa: E501
self.backend = XFormersBackend(num_heads, head_size, scale, self.backend = XFormersBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes, num_kv_heads, alibi_slopes,
sliding_window) sliding_window)
......
...@@ -30,9 +30,10 @@ def fused_moe_kernel( ...@@ -30,9 +30,10 @@ def fused_moe_kernel(
K, K,
EM, EM,
num_valid_tokens, num_valid_tokens,
# The stride variables represent how much to increase the ptr by when moving by 1 # The stride variables represent how much to increase the ptr by when
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` # moving by 1 element in a particular dimension. E.g. `stride_am` is
# by to get the element one row down (A has M rows). # how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am, stride_am,
stride_ak, stride_ak,
stride_be, stride_be,
...@@ -50,17 +51,30 @@ def fused_moe_kernel( ...@@ -50,17 +51,30 @@ def fused_moe_kernel(
compute_type: tl.constexpr, compute_type: tl.constexpr,
): ):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters: Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. - A: The input tensor representing tokens with shape (*, K), where '*' can
- B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. be any shape representing batches and K is the feature dimension of
- C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, each token.
and N is the output feature dimension. - B: The stacked MOE weight tensor with shape (E, N, K), where E is
- sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. the number of experts, K is the input feature dimension, and N is
- expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. the output feature dimension.
This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` - C: The output cache tensor with shape (M, topk, N), where M is the
by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
""" """
# ----------------------------------------------------------- # -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute. # Map program ids `pid` to the block of C it should compute.
...@@ -105,7 +119,8 @@ def fused_moe_kernel( ...@@ -105,7 +119,8 @@ def fused_moe_kernel(
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension. # Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(a_ptrs, a = tl.load(a_ptrs,
mask=token_mask[:, None] & mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K), (offs_k[None, :] < K - k * BLOCK_SIZE_K),
...@@ -139,30 +154,41 @@ def moe_align_block_size( ...@@ -139,30 +154,41 @@ def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int, topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block size for matrix multiplication. Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters: Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. - topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication. - block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts. - num_experts: The total number of experts.
Returns: Returns:
- sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. - sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block. - expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. - num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. This function pads the number of tokens that each expert needs to process
Padding ensures that during block matrix multiplication, the dimensions align correctly. so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example: Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert. - As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block. - Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. - After sorting by expert index, we obtain token_ids
Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication. [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
- The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
""" """
sorted_ids = torch.empty( sorted_ids = torch.empty(
(topk_ids.numel() + num_experts * (block_size - 1), ), (topk_ids.numel() + num_experts * (block_size - 1), ),
...@@ -224,13 +250,14 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: ...@@ -224,13 +250,14 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
""" """
Return optimized configurations for the fused MoE kernel. Return optimized configurations for the fused MoE kernel.
The return value will be a dictionary that maps an irregular grid of batch sizes The return value will be a dictionary that maps an irregular grid of
to configurations of the fused_moe kernel. To evaluate the kernel on a given batch batch sizes to configurations of the fused_moe kernel. To evaluate the
size bs, the closest batch size in the grid should be picked and the associated kernel on a given batch size bs, the closest batch size in the grid should
configuration chosen to invoke the kernel. be picked and the associated configuration chosen to invoke the kernel.
""" """
# First look up if an optimized configuration is available in the configs directory # First look up if an optimized configuration is available in the configs
# directory
device_name = torch.cuda.get_device_name().replace(" ", "_") device_name = torch.cuda.get_device_name().replace(" ", "_")
config_file_path = os.path.join( config_file_path = os.path.join(
...@@ -243,7 +270,8 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: ...@@ -243,7 +270,8 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
# If a configuration has been found, return it # If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()} return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default configuration # If no optimized configuration is available, we will use the default
# configuration
return None return None
...@@ -258,18 +286,22 @@ def fused_moe( ...@@ -258,18 +286,22 @@ def fused_moe(
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters: Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights. - w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights. - w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation (before softmax). - gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select. - topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place. Defaults to False. - inplace (bool): If True, perform the operation in-place.
- override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
...@@ -325,7 +357,8 @@ def fused_moe( ...@@ -325,7 +357,8 @@ def fused_moe(
configs = get_moe_configs(E, w2.shape[2]) configs = get_moe_configs(E, w2.shape[2])
if configs: if configs:
# If an optimal configuration map has been found, look up the optimal config # If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))] config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else: else:
# Else use the default config # Else use the default config
......
...@@ -285,7 +285,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -285,7 +285,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling. # If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
...@@ -307,7 +308,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -307,7 +308,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling. # If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
...@@ -413,7 +415,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -413,7 +415,8 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling. # If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
...@@ -442,7 +445,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -442,7 +445,8 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling. # If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
......
from typing import Type from typing import Type
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment