Commit 40542023 authored by zhuwenwen

merge v0.3.2

parents 5e5b497d 8fbd84bf
@@ -33,7 +33,7 @@ class FCFS(Policy):
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
-        return now - seq_group.arrival_time
+        return now - seq_group.metrics.arrival_time
class PolicyFactory:
...
@@ -365,10 +365,13 @@ class Scheduler:
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_outputs = self._schedule()
+        now = time.time()
        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in scheduler_outputs.scheduled_seq_groups:
+            seq_group.maybe_set_first_scheduled_time(now)
            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -384,6 +387,7 @@ class Scheduler:
                block_tables=block_tables,
                lora_request=seq_group.lora_request,
                prefix=seq_group.prefix,
+                state=seq_group.state,
            )
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs
...
@@ -32,6 +32,7 @@ class EngineArgs:
    max_paddings: int = 256
    disable_log_stats: bool = False
    revision: Optional[str] = None
+    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
@@ -75,6 +76,13 @@ class EngineArgs:
            help='the specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
+        parser.add_argument(
+            '--code-revision',
+            type=str,
+            default=None,
+            help='the specific revision to use for the model code on '
+            'Hugging Face Hub. It can be a branch name, a tag name, or a '
+            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=str,
@@ -165,7 +173,6 @@ class EngineArgs:
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='token block size')
-        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
@@ -279,13 +286,12 @@ class EngineArgs:
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
               DeviceConfig, Optional[LoRAConfig]]:
        device_config = DeviceConfig(self.device)
-        model_config = ModelConfig(self.model, self.tokenizer,
-                                   self.tokenizer_mode, self.trust_remote_code,
-                                   self.download_dir, self.load_format,
-                                   self.dtype, self.seed, self.revision,
-                                   self.tokenizer_revision, self.max_model_len,
-                                   self.quantization, self.enforce_eager,
-                                   self.max_context_len_to_capture)
+        model_config = ModelConfig(
+            self.model, self.tokenizer, self.tokenizer_mode,
+            self.trust_remote_code, self.download_dir, self.load_format,
+            self.dtype, self.seed, self.revision, self.code_revision,
+            self.tokenizer_revision, self.max_model_len, self.quantization,
+            self.enforce_eager, self.max_context_len_to_capture)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
...
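Usage sketch for the new code_revision option added above; the model id is a placeholder and fetching its config requires network access to the Hub:

from vllm.engine.arg_utils import EngineArgs

# Hypothetical example: pin the remote modeling code to a specific Hub
# revision, independently of the weights/config revision.
engine_args = EngineArgs(
    model="my-org/my-remote-code-model",  # placeholder model id
    trust_remote_code=True,
    revision="main",          # weights/config revision
    code_revision="v1.0.0",   # revision of the modeling code on the Hub
)
model_config, *_ = engine_args.create_engine_configs()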
@@ -464,6 +464,9 @@ class LLMEngine:
            prompt_token_ids[:prefix_pos], lora_request.lora_int_id
            if lora_request else 0) if prefix_pos is not None else None
+        # Defensive copy of SamplingParams, which are used by the sampler
+        sampling_params = copy.deepcopy(sampling_params)
        # Create the sequence group.
        seq_group = SequenceGroup(request_id, [seq], sampling_params,
                                  arrival_time, lora_request, prefix)
@@ -725,6 +728,7 @@ class LLMEngine:
    def _process_model_outputs(
            self, output: SamplerOutput,
            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
+        now = time.time()
        # Update the scheduled sequence groups with the model outputs.
        scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
        for seq_group, outputs in zip(scheduled_seq_groups, output):
@@ -736,6 +740,7 @@ class LLMEngine:
        # Create the outputs.
        request_outputs: List[RequestOutput] = []
        for seq_group in scheduled_seq_groups:
+            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)
        for seq_group in scheduler_outputs.ignored_seq_groups:
@@ -864,18 +869,21 @@ class LLMEngine:
        # Number of Tokens.
        if prompt_run:
-            num_prompt_tokens = scheduler_outputs.num_batched_tokens
+            num_prompt_tokens = sum(
+                len(seq_group.prompt_token_ids)
+                for seq_group in scheduler_outputs.scheduled_seq_groups)
        else:
            num_generation_tokens = scheduler_outputs.num_batched_tokens
        # Latency Timings.
        time_last_iters = []
        for seq_group in scheduler_outputs.scheduled_seq_groups:
-            # Time since last token. (n.b. updates seq_group.last_token_time)
+            # Time since last token. (n.b. updates seq_group.metrics.last_token_time)
            time_last_iters.append(seq_group.get_last_latency(now))
            # Time since arrival for all finished requests.
            if seq_group.is_finished():
-                time_e2e_requests.append(now - seq_group.arrival_time)
+                time_e2e_requests.append(now -
+                                         seq_group.metrics.arrival_time)
        time_to_first_tokens = time_last_iters if prompt_run else []
        time_per_output_tokens = [] if prompt_run else time_last_iters
...
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine and simple performance benchmarks.
It is not intended for production use. For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead.
"""
import argparse import argparse
import json import json
from typing import AsyncGenerator from typing import AsyncGenerator
......
@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRe
from vllm.logger import init_logger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_engine import LoRA
TIMEOUT_KEEP_ALIVE = 5  # seconds
@@ -48,6 +49,16 @@ async def lifespan(app: fastapi.FastAPI):
app = fastapi.FastAPI(lifespan=lifespan)
+class LoRAParserAction(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        lora_list = []
+        for item in values:
+            name, path = item.split('=')
+            lora_list.append(LoRA(name, path))
+        setattr(namespace, self.dest, lora_list)
def parse_args():
    parser = argparse.ArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
@@ -81,6 +92,15 @@ def parse_args():
                        help="The model name used in the API. If not "
                        "specified, the model name will be the same as "
                        "the huggingface name.")
+    parser.add_argument(
+        "--lora-modules",
+        type=str,
+        default=None,
+        nargs='+',
+        action=LoRAParserAction,
+        help=
+        "LoRA module configurations in the format name=path. Multiple modules can be specified."
+    )
    parser.add_argument("--chat-template",
                        type=str,
                        default=None,
@@ -217,8 +237,10 @@ if __name__ == "__main__":
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    openai_serving_chat = OpenAIServingChat(engine, served_model,
                                            args.response_role,
+                                            args.lora_modules,
                                            args.chat_template)
-    openai_serving_completion = OpenAIServingCompletion(engine, served_model)
+    openai_serving_completion = OpenAIServingCompletion(
+        engine, served_model, args.lora_modules)
    # Register labels for metrics
    add_global_metrics_labels(model_name=engine_args.model)
...
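A hypothetical usage sketch for --lora-modules: start the server with an adapter mapping, then address the adapter by its name through the OpenAI-compatible API. Adapter name, path, and base model are placeholders, not part of this commit:

from openai import OpenAI

# Assumes the server was started with something like:
#   python -m vllm.entrypoints.openai.api_server \
#       --model meta-llama/Llama-2-7b-hf \
#       --lora-modules sql-lora=/path/to/sql_adapter
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# The LoRA adapter is exposed as an additional model name.
completion = client.completions.create(
    model="sql-lora",  # adapter name from --lora-modules
    prompt="SELECT * FROM users WHERE",
    max_tokens=32,
)
print(completion.choices[0].text)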
@@ -60,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
+    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
@@ -90,6 +91,7 @@ class ChatCompletionRequest(BaseModel):
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
+            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=self.max_tokens,
@@ -117,6 +119,7 @@ class CompletionRequest(BaseModel):
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    seed: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    best_of: Optional[int] = None
@@ -147,6 +150,7 @@ class CompletionRequest(BaseModel):
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
+            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
...
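A sketch of the new per-request seed as seen from the OpenAI-compatible API; the model name is a placeholder, and the expectation is that identical seeded requests sample the same tokens:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",  # whatever the server is serving
    messages=[{"role": "user", "content": "Tell me a joke."}],
    temperature=0.8,
    seed=1234,  # forwarded into SamplingParams via to_sampling_params()
)
print(resp.choices[0].message.content)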
import time
import codecs
from fastapi import Request
-from typing import AsyncGenerator, AsyncIterator, Union
+from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
    UsageInfo)
from vllm.outputs import RequestOutput
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
logger = init_logger(__name__)
@@ -22,8 +22,11 @@ class OpenAIServingChat(OpenAIServing):
                 engine: AsyncLLMEngine,
                 served_model: str,
                 response_role: str,
+                 lora_modules: Optional[List[LoRA]] = None,
                 chat_template=None):
-        super().__init__(engine=engine, served_model=served_model)
+        super().__init__(engine=engine,
+                         served_model=served_model,
+                         lora_modules=lora_modules)
        self.response_role = response_role
        self._load_chat_template(chat_template)
@@ -64,11 +67,13 @@ class OpenAIServingChat(OpenAIServing):
            token_ids = self._validate_prompt_and_tokenize(request,
                                                           prompt=prompt)
            sampling_params = request.to_sampling_params()
+            lora_request = self._maybe_get_lora(request)
        except ValueError as e:
            return self.create_error_response(str(e))
        result_generator = self.engine.generate(prompt, sampling_params,
-                                                request_id, token_ids)
+                                                request_id, token_ids,
+                                                lora_request)
        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
...
@@ -15,7 +15,7 @@ from .protocol import (
    UsageInfo,
)
from vllm.outputs import RequestOutput
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
logger = init_logger(__name__)
@@ -249,8 +249,13 @@ def merge_async_iterators(*iterators):
class OpenAIServingCompletion(OpenAIServing):
-    def __init__(self, engine: AsyncLLMEngine, served_model: str):
-        super().__init__(engine=engine, served_model=served_model)
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model: str,
+                 lora_modules: Optional[List[LoRA]] = None):
+        super().__init__(engine=engine,
+                         served_model=served_model,
+                         lora_modules=lora_modules)
    async def create_completion(self, request: CompletionRequest,
                                raw_request: Request):
@@ -284,6 +289,7 @@ class OpenAIServingCompletion(OpenAIServing):
        generators = []
        try:
            sampling_params = request.to_sampling_params()
+            lora_request = self._maybe_get_lora(request)
            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
            for i, prompt in enumerate(prompts):
@@ -298,7 +304,8 @@ class OpenAIServingCompletion(OpenAIServing):
                    self.engine.generate(None,
                                         sampling_params,
                                         f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids))
+                                         prompt_token_ids=input_ids,
+                                         lora_request=lora_request))
        except ValueError as e:
            return self.create_error_response(str(e))
...
import asyncio
+from dataclasses import dataclass
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from vllm.logger import init_logger
@@ -9,15 +10,35 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              ErrorResponse, LogProbs,
                                              ModelCard, ModelList,
                                              ModelPermission)
+from vllm.lora.request import LoRARequest
logger = init_logger(__name__)
+@dataclass
+class LoRA:
+    name: str
+    local_path: str
class OpenAIServing:
-    def __init__(self, engine: AsyncLLMEngine, served_model: str):
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model: str,
+                 lora_modules=Optional[List[LoRA]]):
        self.engine = engine
        self.served_model = served_model
+        if lora_modules is None:
+            self.lora_requests = []
+        else:
+            self.lora_requests = [
+                LoRARequest(
+                    lora_name=lora.name,
+                    lora_int_id=i,
+                    lora_local_path=lora.local_path,
+                ) for i, lora in enumerate(lora_modules, start=1)
+            ]
        self.max_model_len = 0
        self.tokenizer = None
@@ -50,6 +71,13 @@ class OpenAIServing:
                      root=self.served_model,
                      permission=[ModelPermission()])
        ]
+        lora_cards = [
+            ModelCard(id=lora.lora_name,
+                      root=self.served_model,
+                      permission=[ModelPermission()])
+            for lora in self.lora_requests
+        ]
+        model_cards.extend(lora_cards)
        return ModelList(data=model_cards)
    def _create_logprobs(
@@ -99,11 +127,22 @@ class OpenAIServing:
    async def _check_model(self, request) -> Optional[ErrorResponse]:
        if request.model == self.served_model:
            return
+        if request.model in [lora.lora_name for lora in self.lora_requests]:
+            return
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)
+    def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
+        if request.model == self.served_model:
+            return
+        for lora in self.lora_requests:
+            if request.model == lora.lora_name:
+                return lora
+        # if _check_model has been called earlier, this will be unreachable
+        raise ValueError("The model `{request.model}` does not exist.")
    def _validate_prompt_and_tokenize(
        self,
        request: Union[ChatCompletionRequest, CompletionRequest],
...
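For illustration, a standalone sketch of the name-to-adapter resolution that _maybe_get_lora performs above; adapter names, paths, and the served model name are placeholders:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class LoRA:  # mirrors the dataclass added in this commit
    name: str
    local_path: str

# e.g. parsed from: --lora-modules sql-lora=/adapters/sql math-lora=/adapters/math
modules = [LoRA("sql-lora", "/adapters/sql"), LoRA("math-lora", "/adapters/math")]

def resolve(model_name: str, served_model: str, mods: List[LoRA]) -> Optional[LoRA]:
    # Requests for the base model get no adapter; adapter names map to their entry.
    if model_name == served_model:
        return None
    for lora in mods:
        if model_name == lora.name:
            return lora
    raise ValueError(f"The model `{model_name}` does not exist.")

print(resolve("math-lora", "base-model", modules).local_path)  # /adapters/math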
@@ -5,6 +5,8 @@ import logging
import sys
import os
+VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"
@@ -45,13 +47,15 @@ def _setup_logger():
# The logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
-_setup_logger()
+if VLLM_CONFIGURE_LOGGING:
+    _setup_logger()
def init_logger(name: str):
    # Use the same settings as above for root logger
    logger = logging.getLogger(name)
    logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
-    logger.addHandler(_default_handler)
-    logger.propagate = False
+    if VLLM_CONFIGURE_LOGGING:
+        logger.addHandler(_default_handler)
+        logger.propagate = False
    return logger
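Usage sketch for the VLLM_CONFIGURE_LOGGING opt-out added above; the variable must be set before vLLM is first imported, since the check runs at import time:

import os

# Disable vLLM's built-in handler/propagation setup.
os.environ["VLLM_CONFIGURE_LOGGING"] = "0"

import logging
import vllm  # noqa: F401  (vLLM now leaves logging configuration to the app)

logging.basicConfig(level=logging.INFO)  # application-controlled logging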
@@ -342,7 +342,9 @@ def _beam_search_sample(
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
-):
+    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
+    generators: Optional[List[torch.Generator]] = None,
+) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleaved (which also
        # forces a GPU<->CPU sync).
@@ -352,7 +354,15 @@ def _multinomial(
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
-    q = torch.empty_like(probs).exponential_(1)
+    q = torch.empty_like(probs)
+    if seq_groups is None:
+        q.exponential_()
+    else:
+        sample_idx = 0
+        for (seq_ids, _), generator in zip(seq_groups, generators):
+            next_sample_idx = sample_idx + len(seq_ids) * num_samples
+            q[sample_idx:next_sample_idx].exponential_(generator=generator)
+            sample_idx = next_sample_idx
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
@@ -370,6 +380,7 @@ def _sample(
    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
+    multinomial_samples = {}
    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
@@ -385,14 +396,18 @@ def _sample(
                                           is_prompts, sample_indices)
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_best_of = 1
            for seq_group, is_prompt in zip(seq_groups, is_prompts):
                if is_prompt:
                    _, sampling_params = seq_group
                    max_best_of = max(max_best_of, sampling_params.best_of)
-            multinomial_samples = _multinomial(probs[sample_indices],
-                                               max_best_of)
+            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
+                "seq_groups": seq_groups,
+                "generators": sampling_metadata.generators,
+            }
+            multinomial_samples[sampling_type] = _multinomial(
+                probs[sample_indices], max_best_of, **seeded_args)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
@@ -407,9 +422,9 @@ def _sample(
                                                         sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups, is_prompts,
-                                            multinomial_samples)
+                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 sampling_metadata.seq_data,
...
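For context on the seeded path above: _multinomial samples by dividing probabilities by i.i.d. Exponential(1) noise and taking the argmax (an exponential-race variant of the Gumbel-max trick), so supplying a per-request torch.Generator makes the draw reproducible. A minimal standalone illustration, not vLLM code:

import torch

probs = torch.tensor([[0.1, 0.7, 0.2]])

def seeded_multinomial(p: torch.Tensor, seed: int) -> torch.Tensor:
    gen = torch.Generator(device=p.device).manual_seed(seed)
    # argmax of p / Exponential(1) draws index i with probability p[i].
    q = torch.empty_like(p).exponential_(generator=gen)
    return p.div(q).argmax(dim=1)

print(seeded_multinomial(probs, 1234))  # same seed -> same token index
print(seeded_multinomial(probs, 1234))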
@@ -20,6 +20,7 @@ _MODELS = {
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
+    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
@@ -35,6 +36,7 @@ _MODELS = {
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
+    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
...
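With the registry entries above, the new architectures load through the usual vLLM entry points. A hypothetical sketch; the model id is only an example and Gemma checkpoints may require accepting the license on the Hub:

from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-2b", dtype="bfloat16")  # example model id

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)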
# coding=utf-8
# Copyright 2023 The vLLM team.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import GemmaConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * (1 + self.weight)
class GemmaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.gate_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=False,
linear_method=linear_method)
self.up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=False,
linear_method=linear_method)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
self.act_fn = nn.GELU()
def forward(self, x):
gate, _ = self.gate_proj(x)
gate = self.act_fn(gate)
up, _ = self.up_proj(x)
fuse = gate * up
outputs, _ = self.down_proj(fuse)
return outputs
class GemmaAttention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_position_embeddings: int = 8192,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=self.rope_theta,
is_neox_style=True,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class GemmaDecoderLayer(nn.Module):
def __init__(
self,
config: GemmaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GemmaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta,
linear_method=linear_method,
)
self.mlp = GemmaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
linear_method=linear_method,
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class GemmaModel(nn.Module):
def __init__(
self,
config: GemmaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
GemmaDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Normalize the embedding by sqrt(hidden_size)
hidden_states = hidden_states * (self.config.hidden_size**0.5)
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class GemmaForCausalLM(nn.Module):
def __init__(
self,
config: GemmaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = GemmaModel(config, linear_method)
self.sampler = Sampler(config.vocab_size)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.embed_tokens.weight,
hidden_states, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params = set()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra layer for lora models.
if "lm_head" in name:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
f"Some weights are not initialized from checkpoints: {unloaded_params}"
)
# coding=utf-8
# Adapted from
# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size, )
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.olmo import OLMoConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class SwiGLU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x
@property
def output_multiplier(self) -> float:
return 0.5
class OlmoAttention(nn.Module):
"""
This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(
self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.hidden_size = config.d_model
assert config.d_model % config.n_heads == 0
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
self.total_num_heads = self.config.n_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // self.total_num_heads
# Layer norms.
self.attn_norm = nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False)
# Attention input projection. Projects x -> (q, k, v)
self.att_proj = QKVParallelLinear(
config.d_model,
self.head_dim,
self.total_num_heads,
bias=config.include_bias,
linear_method=linear_method,
)
# Rotary embeddings.
if self.config.rope:
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config,
"max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling)
# Attention output projection.
self.attn_out = RowParallelLinear(
config.d_model,
config.d_model,
bias=config.include_bias,
linear_method=linear_method,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.attn_norm(hidden_states)
qkv, _ = self.att_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.config.rope:
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.attn_out(attn_output)
return output
class OlmoMLP(nn.Module):
"""
This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(
self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
is not None else config.mlp_ratio * config.d_model)
# Layer norms.
self.ff_norm = nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False)
# Feed-forward input projection.
self.ff_proj = ColumnParallelLinear(
config.d_model,
self.hidden_size,
bias=config.include_bias,
linear_method=linear_method,
)
# Activation function.
# self.act = SiluAndMul()
# self.act.output_multiplier = 0.5
self.act = SwiGLU()
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
# Feed-forward output projection.
self.ff_out = RowParallelLinear(
int(self.act.output_multiplier * self.hidden_size),
config.d_model,
bias=config.include_bias,
linear_method=linear_method,
)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
og_x = x
x = self.ff_norm(x)
x, _ = self.ff_proj(x)
x = self.act(x)
x, _ = self.ff_out(x)
x = og_x + x
return x
class OlmoBlock(nn.Module):
"""
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
# Attention block.
self.attn = OlmoAttention(config, linear_method)
# MLP block.
self.mlp = OlmoMLP(config, linear_method)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Attention block.
og_x = hidden_states
x = self.attn(positions, hidden_states, kv_cache, input_metadata)
x = x + og_x
# MLP block.
hidden_states = self.mlp(x)
return hidden_states
class OlmoModel(nn.Module):
def __init__(self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict(
dict(
wte=VocabParallelEmbedding(
config.embedding_size or config.vocab_size,
config.d_model,
),
ln_f=nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False),
))
blocks = [
OlmoBlock(config, linear_method) for i in range(config.n_layers)
]
if self.config.block_group_size > 1:
raise NotImplementedError("Block group size > 1 not supported yet")
else:
self.transformer.update({"blocks": nn.ModuleList(blocks)})
if not config.weight_tying:
self.transformer.update({
"ff_out":
ColumnParallelLinear(
config.d_model,
config.embedding_size or config.vocab_size,
bias=config.include_bias,
linear_method=linear_method,
)
})
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
x = self.transformer.wte(input_ids) # type: ignore
# Apply blocks one-by-one.
for block_idx, block in enumerate(self.transformer.blocks):
# shape: (batch_size, seq_len, d_model)
x = block(
positions,
x,
kv_caches[block_idx],
input_metadata,
)
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model)
x = self.transformer.ln_f(x) # type: ignore
return x
class OLMoForCausalLM(nn.Module):
"""
Extremely barebones HF model wrapper.
"""
def __init__(self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = OlmoModel(config, linear_method)
self.lm_head_weight = (self.model.transformer.wte.weight
if config.weight_tying else
self.model.transformer.ff_out.weight)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
input_metadata=input_metadata,
)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
# attention
if ".att" in name:
name = name.replace(".att", ".attn.att")
# mlp
if ".ff" in name and "transformer.ff_out" not in name:
name = name.replace(".ff", ".mlp.ff")
# there is no bias in olmo
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
@@ -19,6 +19,7 @@ class SamplingMetadata:
        prompt_lens: Lengths of prompts.
        selected_token_indices: Token indices selected for sampling.
        categorized_sample_indices: SamplingType -> token indices to sample.
+        generators: List of torch.Generators to use for seeded sampling
        perform_sampling: Whether to perform sampling. This option is used to
            make sampling happen only in the driver worker, and disable
            sampling in other worker processes.
@@ -31,6 +32,7 @@ class SamplingMetadata:
        prompt_lens: Optional[List[int]],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
+        generators: Optional[List[torch.Generator]] = None,
        perform_sampling: bool = True,
    ) -> None:
        self.seq_groups = seq_groups
@@ -38,6 +40,7 @@ class SamplingMetadata:
        self.prompt_lens = prompt_lens
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
+        self.generators = generators
        self.perform_sampling = perform_sampling
        self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
...
from typing import List, Optional
+import time
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
-                           SequenceStatus)
+                           SequenceStatus, RequestMetrics)
from vllm.lora.request import LoRARequest
@@ -60,6 +61,7 @@ class RequestOutput:
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
+        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
    """
@@ -71,6 +73,7 @@ class RequestOutput:
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
+        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
@@ -79,6 +82,7 @@ class RequestOutput:
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
+        self.metrics = metrics
        self.lora_request = lora_request
    @classmethod
@@ -115,12 +119,15 @@ class RequestOutput:
        prompt_token_ids = seq_group.prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
+        finished_time = time.time() if finished else None
+        seq_group.set_finished_time(finished_time)
        return cls(seq_group.request_id,
                   prompt,
                   prompt_token_ids,
                   prompt_logprobs,
                   outputs,
                   finished,
+                   seq_group.metrics,
                   lora_request=seq_group.lora_request)
    def __repr__(self) -> str:
@@ -130,4 +137,5 @@ class RequestOutput:
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
+                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request})")
@@ -11,7 +11,8 @@ _SAMPLING_EPS = 1e-5
class SamplingType(IntEnum):
    GREEDY = 0
    RANDOM = 1
-    BEAM = 2
+    RANDOM_SEED = 2
+    BEAM = 3
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
@@ -56,6 +57,7 @@ class SamplingParams:
    min_p: Float that represents the minimum probability for a token to be
        considered, relative to the probability of the most likely token.
        Must be in [0, 1]. Set to 0 to disable this.
+    seed: Random seed to use for the generation.
    use_beam_search: Whether to use beam search instead of sampling.
    length_penalty: Float that penalizes sequences based on their length.
        Used in beam search.
@@ -101,6 +103,7 @@ class SamplingParams:
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
+        seed: Optional[int] = None,
        use_beam_search: bool = False,
        length_penalty: float = 1.0,
        early_stopping: Union[bool, str] = False,
@@ -124,6 +127,7 @@ class SamplingParams:
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p
+        self.seed = seed
        self.use_beam_search = use_beam_search
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
@@ -229,6 +233,8 @@ class SamplingParams:
            return SamplingType.BEAM
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
+        if self.seed is not None:
+            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM
    def __repr__(self) -> str:
@@ -242,6 +248,7 @@ class SamplingParams:
                f"top_p={self.top_p}, "
                f"top_k={self.top_k}, "
                f"min_p={self.min_p}, "
+                f"seed={self.seed}, "
                f"use_beam_search={self.use_beam_search}, "
                f"length_penalty={self.length_penalty}, "
                f"early_stopping={self.early_stopping}, "
...
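A small sketch of how the new seed field selects the sampling type under this change:

from vllm import SamplingParams
from vllm.sampling_params import SamplingType

seeded = SamplingParams(temperature=0.8, seed=42)
unseeded = SamplingParams(temperature=0.8)

# A non-greedy request with a seed is routed to the seeded sampling path,
# so reruns with the same seed should reproduce the same samples.
assert seeded.sampling_type == SamplingType.RANDOM_SEED
assert unseeded.sampling_type == SamplingType.RANDOM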
"""Sequence and its related classes.""" """Sequence and its related classes."""
import copy import copy
import enum import enum
from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from vllm.block import LogicalTokenBlock from vllm.block import LogicalTokenBlock
...@@ -49,6 +50,25 @@ class SequenceStatus(enum.Enum): ...@@ -49,6 +50,25 @@ class SequenceStatus(enum.Enum):
return finish_reason return finish_reason
@dataclass
class RequestMetrics:
"""Metrics associated with a request.
Args:
arrival_time: The time when the request arrived.
first_scheduled_time: The time when the request was first scheduled.
first_token_time: The time when the first token was generated.
time_in_queue: The time the request spent in the queue.
finished_time: The time when the request was finished.
"""
arrival_time: float
last_token_time: float
first_scheduled_time: Optional[float]
first_token_time: Optional[float]
time_in_queue: Optional[float]
finished_time: Optional[float] = None
class SequenceData: class SequenceData:
"""Data associated with a sequence. """Data associated with a sequence.
...@@ -228,6 +248,14 @@ class Sequence: ...@@ -228,6 +248,14 @@ class Sequence:
f"num_blocks={len(self.logical_token_blocks)})") f"num_blocks={len(self.logical_token_blocks)})")
@dataclass
class SequenceGroupState:
"""Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling
generator: Optional = None
class SequenceGroup: class SequenceGroup:
"""A group of sequences that are generated from the same prompt. """A group of sequences that are generated from the same prompt.
...@@ -252,11 +280,15 @@ class SequenceGroup: ...@@ -252,11 +280,15 @@ class SequenceGroup:
self.request_id = request_id self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.seqs_dict = {seq.seq_id: seq for seq in seqs}
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.arrival_time = arrival_time self.metrics = RequestMetrics(arrival_time=arrival_time,
self.last_token_time = arrival_time last_token_time=arrival_time,
first_scheduled_time=None,
first_token_time=None,
time_in_queue=None)
self.lora_request = lora_request self.lora_request = lora_request
self.prefix: Optional[Prefix] = prefix self.prefix: Optional[Prefix] = prefix
self.prompt_logprobs: Optional[PromptLogprobs] = None self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
@property @property
def prompt(self) -> str: def prompt(self) -> str:
...@@ -276,10 +308,25 @@ class SequenceGroup: ...@@ -276,10 +308,25 @@ class SequenceGroup:
def get_last_latency(self, now: float) -> float: def get_last_latency(self, now: float) -> float:
"""Gets last token latency for Request level timings.""" """Gets last token latency for Request level timings."""
latency = now - self.last_token_time latency = now - self.metrics.last_token_time
self.last_token_time = now self.metrics.last_token_time = now
return latency return latency
def maybe_set_first_token_time(self, time: float) -> None:
"""Sets the first token time for Request level timings."""
if self.metrics.first_token_time is None:
self.metrics.first_token_time = time
def maybe_set_first_scheduled_time(self, time: float) -> None:
"""Sets the first scheduled time and time in queue for Request level timings."""
if self.metrics.first_scheduled_time is None:
self.metrics.first_scheduled_time = time
self.metrics.time_in_queue = time - self.metrics.arrival_time
def set_finished_time(self, time: Optional[float]) -> None:
"""Sets the finished time for Request level timings."""
self.metrics.finished_time = time
def get_max_num_running_seqs(self) -> int: def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining """The maximum number of sequences running in parallel in the remaining
lifetime of the request.""" lifetime of the request."""
...@@ -359,6 +406,7 @@ class SequenceGroupMetadata: ...@@ -359,6 +406,7 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs. sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block block_tables: The block tables. (Seq id -> list of physical block
numbers) numbers)
state: Internal state tied to this sequence group.
lora_request: LoRA request. lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group. prefix: The prefix of the prompt of the sequence group.
""" """
...@@ -372,6 +420,7 @@ class SequenceGroupMetadata: ...@@ -372,6 +420,7 @@ class SequenceGroupMetadata:
block_tables: Dict[int, List[int]], block_tables: Dict[int, List[int]],
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None, prefix: Optional[Prefix] = None,
state: Optional[SequenceGroupState] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.is_prompt = is_prompt self.is_prompt = is_prompt
...@@ -380,6 +429,7 @@ class SequenceGroupMetadata: ...@@ -380,6 +429,7 @@ class SequenceGroupMetadata:
self.block_tables = block_tables self.block_tables = block_tables
self.lora_request = lora_request self.lora_request = lora_request
self.prefix = prefix self.prefix = prefix
self.state = SequenceGroupState() if state is None else state
@property @property
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
......
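For illustration, a hypothetical sketch of the role of SequenceGroupState: a seeded request keeps one torch.Generator for its whole lifetime, so the random stream continues across decoding steps and reruns with the same seed reproduce it. The names below are illustrative, not vLLM internals:

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class SequenceGroupState:  # mirrors the dataclass added above
    generator: Optional[torch.Generator] = None

state = SequenceGroupState()
seed = 42

# Assumed pattern: create the generator once for a seeded request, then keep
# reusing the same state object on every scheduling step.
if state.generator is None:
    state.generator = torch.Generator().manual_seed(seed)

draw_step_1 = torch.rand(3, generator=state.generator)
draw_step_2 = torch.rand(3, generator=state.generator)  # continues the stream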
@@ -16,10 +16,14 @@ _CONFIG_REGISTRY = {
def get_config(model: str,
               trust_remote_code: bool,
-               revision: Optional[str] = None) -> PretrainedConfig:
+               revision: Optional[str] = None,
+               code_revision: Optional[str] = None) -> PretrainedConfig:
    try:
        config = AutoConfig.from_pretrained(
-            model, trust_remote_code=trust_remote_code, revision=revision)
+            model,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            code_revision=code_revision)
    except ValueError as e:
        if (not trust_remote_code and
                "requires you to execute the configuration file" in str(e)):
@@ -33,5 +37,7 @@ def get_config(model: str,
            raise e
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model, revision=revision)
+        config = config_class.from_pretrained(model,
+                                              revision=revision,
+                                              code_revision=code_revision)
    return config