Unverified commit d6fa1be3, authored by Zhuohan Li, committed by GitHub

[Quality] Add code formatter and linter (#326)

parent 0ffded81
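The hunks below are mechanical reformatting: long call sites are rewrapped with one argument per line aligned under the opening parenthesis, string literals are normalized to double quotes, and a few lint findings are cleaned up (for example, "== False" assertions and a shadowed loop variable). A condensed before/after illustration of the enforced style, based on one of the hunks below:

    # Before: single-quoted f-strings, continuation indented arbitrarily.
    logger.info(f'# GPU blocks: {num_gpu_blocks}, '
        f'# CPU blocks: {num_cpu_blocks}')

    # After: double quotes, continuation aligned under the opening parenthesis.
    logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                f"# CPU blocks: {num_cpu_blocks}")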
@@ -67,8 +67,7 @@ class LLMEngine:
f"download_dir={model_config.download_dir!r}, "
f"use_np_weights={model_config.use_np_weights}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
- f"seed={model_config.seed})"
- )
+ f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
@@ -78,8 +77,8 @@ class LLMEngine:
self.log_stats = log_stats
self._verify_args()
- self.tokenizer = get_tokenizer(model_config.tokenizer,
- model_config.tokenizer_mode)
+ self.tokenizer = get_tokenizer(
+ model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
self.seq_counter = Counter()
# Create the parallel GPU workers.
@@ -129,8 +128,8 @@ class LLMEngine:
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
# FIXME(woosuk): Change to debug log.
- logger.info(f'# GPU blocks: {num_gpu_blocks}, '
- f'# CPU blocks: {num_cpu_blocks}')
+ logger.info(f"# GPU blocks: {num_gpu_blocks}, "
+ f"# CPU blocks: {num_cpu_blocks}")
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
@@ -152,7 +151,9 @@ class LLMEngine:
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config)
# Create the LLM engine.
- engine = cls(*engine_configs, distributed_init_method, devices,
+ engine = cls(*engine_configs,
+ distributed_init_method,
+ devices,
log_stats=not engine_args.disable_log_stats)
return engine
@@ -226,8 +227,10 @@ class LLMEngine:
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
- seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
- if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
+ (seq_group_metadata_list, scheduler_outputs,
+ ignored_seq_groups) = self.scheduler.schedule()
+ if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
+ and (not ignored_seq_groups)):
# Nothing to do.
return []
@@ -281,8 +284,8 @@ class LLMEngine:
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)]
- self.scheduler.free_seq(seq,
- SequenceStatus.FINISHED_STOPPED)
+ self.scheduler.free_seq(
+ seq, SequenceStatus.FINISHED_STOPPED)
stopped = True
break
if stopped:
@@ -290,7 +293,7 @@ class LLMEngine:
# Check if the sequence has reached max_seq_len.
if (seq.get_len() >=
self.scheduler.scheduler_config.max_seq_len):
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
continue
@@ -302,15 +305,15 @@ class LLMEngine:
# Check if the sequence has generated the EOS token.
if not sampling_params.ignore_eos:
if seq.get_last_token_id() == self.tokenizer.eos_token_id:
- self.scheduler.free_seq(seq,
- SequenceStatus.FINISHED_STOPPED)
+ self.scheduler.free_seq(
+ seq, SequenceStatus.FINISHED_STOPPED)
continue
def _run_workers(
self,
method: str,
- get_all_outputs: bool = False,
*args,
+ get_all_outputs: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
...
@@ -8,7 +8,8 @@ except ImportError:
from vllm.config import ParallelConfig
- DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), device id
+ # rank, node resource (node IP), device id
+ DeviceID = Tuple[int, Optional[str], int]
def initialize_cluster(
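For reference, a DeviceID is a plain tuple. With an invented node address of 10.0.0.1, rank 0 on the first GPU of that node would look roughly like this; the middle element is whatever node resource key Ray reports, which is what initialize_cluster collects from ray.nodes() below:

    device: DeviceID = (0, "node:10.0.0.1", 0)  # (rank, node resource key, device id)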
@@ -53,15 +54,15 @@ def initialize_cluster(
valid_node_resources = []
num_devices_per_node = None
for node in ray.nodes():
- if (not node['Alive']) or node['Resources']['GPU'] <= 0:
+ if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
continue
if num_devices_per_node is None:
- num_devices_per_node = node['Resources']['GPU']
+ num_devices_per_node = node["Resources"]["GPU"]
else:
- assert num_devices_per_node == node['Resources']['GPU'], (
+ assert num_devices_per_node == node["Resources"]["GPU"], (
"The number of GPUs per node is not uniform.")
- for key in node['Resources']:
- if key.startswith('node:'):
+ for key in node["Resources"]:
+ if key.startswith("node:"):
valid_node_resources.append(key)
# Verify the parallel config.
...
@@ -11,8 +11,8 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
- TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
+ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
app = FastAPI()
@@ -37,8 +37,7 @@ async def generate(request: Request) -> Response:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [
- prompt + output.text
- for output in request_output.outputs
+ prompt + output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")
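The generator above streams null-delimited JSON objects rather than server-sent events. A minimal client sketch for consuming that stream (assuming the server's /generate route and its prompt/stream request fields, plus the third-party requests package; illustrative only, not part of this diff):

    import json
    import requests

    response = requests.post("http://localhost:8000/generate",
                             json={"prompt": "Hello", "stream": True},
                             stream=True)
    # Each chunk is a JSON object terminated by a null byte ("\0").
    for chunk in response.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk.decode("utf-8"))["text"])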
@@ -63,10 +62,7 @@ async def generate(request: Request) -> Response:
assert final_output is not None
prompt = final_output.prompt
- text_outputs = [
- prompt + output.text
- for output in final_output.outputs
- ]
+ text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return Response(content=json.dumps(ret))
@@ -81,5 +77,8 @@ if __name__ == "__main__":
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
- uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
+ uvicorn.run(app,
+ host=args.host,
+ port=args.port,
+ log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
@@ -63,8 +63,7 @@ class LLM:
self.request_counter = Counter()
def get_tokenizer(
- self,
- ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+ self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer
def set_tokenizer(
...
- # Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
+ # Adapted from
+ # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
import argparse
from http import HTTPStatus
@@ -29,7 +30,7 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__)
served_model = None
@@ -38,14 +39,13 @@ app = fastapi.FastAPI()
def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
- return JSONResponse(
- ErrorResponse(message=message, type="invalid_request_error").dict(),
- status_code=status_code.value
- )
+ return JSONResponse(ErrorResponse(message=message,
+ type="invalid_request_error").dict(),
+ status_code=status_code.value)
@app.exception_handler(RequestValidationError)
- async def validation_exception_handler(request, exc):
+ async def validation_exception_handler(request, exc):  # pylint: disable=unused-argument
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
@@ -126,8 +126,11 @@ async def check_length(request, prompt, engine):
@app.get("/v1/models")
async def show_available_models():
"""Show available models. Right now we only have one model."""
- model_cards = [ModelCard(id=served_model, root=served_model,
- permission=[ModelPermission()])]
+ model_cards = [
+ ModelCard(id=served_model,
+ root=served_model,
+ permission=[ModelPermission()])
+ ]
return ModelList(data=model_cards)
@@ -144,12 +147,14 @@ def create_logprobs(token_ids: List[int],
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
- logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
+ logprobs.text_offset.append(logprobs.text_offset[-1] +
+ last_token_len)
last_token_len = len(token)
- logprobs.top_logprobs.append(
- {tokenizer.convert_ids_to_tokens(i): p
- for i, p in id_logprob.items()})
+ logprobs.top_logprobs.append({
+ tokenizer.convert_ids_to_tokens(i): p
+ for i, p in id_logprob.items()
+ })
return logprobs
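To make the offset bookkeeping concrete: text_offset[i] is the character position at which token i starts in the generated text. For hypothetical tokens ["Hello", " world", "!"] and an initial offset of 0 (assumed default), the accumulation above works out as follows:

    tokens = ["Hello", " world", "!"]
    text_offset, last_token_len = [], 0
    for token in tokens:
        if len(text_offset) == 0:
            text_offset.append(0)  # initial_text_offset, assumed to be 0 here
        else:
            text_offset.append(text_offset[-1] + last_token_len)
        last_token_len = len(token)
    print(text_offset)  # [0, 5, 11]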
@@ -348,7 +353,7 @@ async def create_completion(raw_request: Request):
if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")
if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine.
@@ -387,22 +392,23 @@ async def create_completion(raw_request: Request):
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
- result_generator = engine.generate(prompt, sampling_params,
- request_id)
+ result_generator = engine.generate(prompt, sampling_params, request_id)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
- stream = (request.stream and
- (request.best_of is None or request.n == request.best_of) and
- not request.use_beam_search)
+ stream = (request.stream
+ and (request.best_of is None or request.n == request.best_of)
+ and not request.use_beam_search)
async def abort_request() -> None:
await engine.abort(request_id)
- def create_stream_response_json(index: int,
- text: str,
- logprobs: Optional[LogProbs] = None,
- finish_reason: Optional[str] = None) -> str:
+ def create_stream_response_json(
+ index: int,
+ text: str,
+ logprobs: Optional[LogProbs] = None,
+ finish_reason: Optional[str] = None,
+ ) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
@@ -443,7 +449,8 @@ async def create_completion(raw_request: Request):
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
- logprobs = LogProbs() if request.logprobs is not None else None
+ logprobs = (LogProbs()
+ if request.logprobs is not None else None)
response_json = create_stream_response_json(
index=i,
text="",
@@ -487,8 +494,8 @@ async def create_completion(raw_request: Request):
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
- num_generated_tokens = sum(len(output.token_ids)
- for output in final_res.outputs)
+ num_generated_tokens = sum(
+ len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
@@ -506,9 +513,11 @@ async def create_completion(raw_request: Request):
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
@@ -517,26 +526,34 @@ async def create_completion(raw_request: Request):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
- description="vLLM OpenAI-Compatible RESTful API server."
- )
- parser.add_argument("--host", type=str, default="localhost", help="host name")
+ description="vLLM OpenAI-Compatible RESTful API server.")
+ parser.add_argument("--host",
+ type=str,
+ default="localhost",
+ help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
- parser.add_argument(
- "--allow-credentials", action="store_true", help="allow credentials"
- )
- parser.add_argument(
- "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
- )
- parser.add_argument(
- "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
- )
- parser.add_argument(
- "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
- )
- parser.add_argument("--served-model-name", type=str, default=None,
- help="The model name used in the API. If not specified, "
- "the model name will be the same as the "
- "huggingface name.")
+ parser.add_argument("--allow-credentials",
+ action="store_true",
+ help="allow credentials")
+ parser.add_argument("--allowed-origins",
+ type=json.loads,
+ default=["*"],
+ help="allowed origins")
+ parser.add_argument("--allowed-methods",
+ type=json.loads,
+ default=["*"],
+ help="allowed methods")
+ parser.add_argument("--allowed-headers",
+ type=json.loads,
+ default=["*"],
+ help="allowed headers")
+ parser.add_argument(
+ "--served-model-name",
+ type=str,
+ default=None,
+ help="The model name used in the API. If not specified, "
+ "the model name will be the same as the "
+ "huggingface name.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
@@ -556,7 +573,11 @@ if __name__ == "__main__":
engine = AsyncLLMEngine.from_engine_args(engine_args)
# A separate tokenizer to map token IDs to strings.
- tokenizer = get_tokenizer(engine_args.tokenizer, engine_args.tokenizer_mode)
+ tokenizer = get_tokenizer(engine_args.tokenizer,
+ tokenizer_mode=engine_args.tokenizer_mode)
- uvicorn.run(app, host=args.host, port=args.port, log_level="info",
+ uvicorn.run(app,
+ host=args.host,
+ port=args.port,
+ log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
- # Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+ # Adapted from
+ # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
@@ -98,7 +99,8 @@ class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
- top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+ top_logprobs: List[Optional[Dict[str,
+ float]]] = Field(default_factory=list)
class CompletionResponseChoice(BaseModel):
...
- # Adapted from https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
+ # Adapted from
+ # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
+ """Logging configuration for vLLM."""
import logging
import sys
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"
...
@@ -2,7 +2,6 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.utils import set_random_seed
__all__ = [
"InputMetadata",
"get_model",
...
@@ -8,11 +8,22 @@ from vllm.sequence import SequenceData
class InputMetadata:
+ """Metadata for input sequences. Used for PagedAttention.
+ Args:
+ seq_groups: List of (seq_ids, sampling_params).
+ seq_data: Seq_id -> SequenceData.
+ prompt_lens: Lengths of prompts.
+ slot_mapping: The address to write the new KV to of each token.
+ context_lens: the length of attention context for each generation token.
+ max_context_len: The maximum context length.
+ block_tables: The block tables. (Seq id -> list of physical block)
+ """
def __init__(
self,
- seq_groups: List[Tuple[List[int], SamplingParams]],  # List of (seq_ids, sampling_params).
- seq_data: Dict[int, SequenceData],  # Seq_id -> SequenceData.
+ seq_groups: List[Tuple[List[int], SamplingParams]],
+ seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
...
@@ -6,9 +6,10 @@ from vllm import activation_ops
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
- "gelu_new": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
- "gelu_fast": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
- "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+ # NOTE: The following GELU functions may introduce small rounding errors.
+ "gelu_new": nn.GELU(approximate="tanh"),
+ "gelu_fast": nn.GELU(approximate="tanh"),
+ "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
}
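The NOTE above refers to the tanh-based approximation: gelu_new, gelu_fast, and gelu_pytorch_tanh are all mapped to nn.GELU(approximate="tanh"), which differs from the exact erf-based GELU by a small amount. A quick sketch of that difference (illustrative, not part of this diff):

    import torch
    import torch.nn as nn

    x = torch.linspace(-5, 5, steps=1000)
    exact = nn.GELU()(x)                     # erf-based GELU
    approx = nn.GELU(approximate="tanh")(x)  # tanh approximation used in the registry
    print((exact - approx).abs().max())      # small (well below 1e-2)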
@@ -25,15 +26,13 @@ class SiluAndMul(nn.Module):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
- """
- def __init__(self):
- super().__init__()
- def forward(
- self,
- x: torch.Tensor,  # (num_tokens, 2 * d)
- ) -> torch.Tensor:  # (num_tokens, d)
+ Shapes:
+ x: (num_tokens, 2 * d)
+ return: (num_tokens, d)
+ """
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
d = x.shape[1] // 2
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
...
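For reference, the SiluAndMul docstring above corresponds to the following plain-PyTorch computation (a sketch for illustration; the class itself fills a preallocated out tensor, presumably through the activation_ops kernel imported above, and d = 11008 is just an example size):

    import torch
    import torch.nn.functional as F

    d = 11008
    x = torch.randn(4, 2 * d)          # (num_tokens, 2 * d)
    out = F.silu(x[:, :d]) * x[:, d:]  # silu(x[:d]) * x[d:]
    print(out.shape)                   # torch.Size([4, 11008]) -> (num_tokens, d)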
@@ -14,6 +14,7 @@ _SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
class PagedAttention(nn.Module):
+ # pylint: disable=line-too-long
"""GPT-style multi-head PagedAttention.
This class takes flattened 1D query, key, and value tensors as input. The
@@ -54,12 +55,20 @@ class PagedAttention(nn.Module):
def multi_query_kv_attention(
self,
- output: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
- query: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
- key: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
- value: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
+ output: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
attn_bias: xops.AttentionBias,
) -> torch.Tensor:
+ """Normal attention for the prompt tokens.
+ Args:
+ output: shape = [num_prompt_tokens, num_heads, head_size]
+ query: shape = [num_prompt_tokens, num_heads, head_size]
+ key: shape = [num_prompt_tokens, num_heads, head_size]
+ value: shape = [num_prompt_tokens, num_heads, head_size]
+ """
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
@@ -76,12 +85,22 @@ class PagedAttention(nn.Module):
def single_query_cached_kv_attention(
self,
- output: torch.Tensor,  # [num_generation_tokens, num_heads, head_size]
- query: torch.Tensor,  # [num_generation_tokens, num_heads, head_size]
- key_cache: torch.Tensor,  # [num_blocks, num_heads, head_size/x, block_size, x]
- value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+ output: torch.Tensor,
+ query: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
input_metadata: InputMetadata,
) -> None:
+ """PagedAttention for the generation tokens.
+ Args:
+ output: shape = [num_generation_tokens, num_heads, head_size]
+ query: shape = [num_generation_tokens, num_heads, head_size]
+ key_cache: shape = [num_blocks, num_heads, head_size/x,
+ block_size, x]
+ value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+ input_metadata: metadata for paged attention.
+ """
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
@@ -97,16 +116,32 @@ class PagedAttention(nn.Module):
def forward(
self,
- query: torch.Tensor,  # [num_tokens, num_heads * head_size]
- key: torch.Tensor,  # [num_tokens, num_heads * head_size]
- value: torch.Tensor,  # [num_tokens, num_heads * head_size]
- key_cache: Optional[torch.Tensor],  # [num_blocks, num_heads, head_size/x, block_size, x]
- value_cache: Optional[torch.Tensor],  # [num_blocks, num_heads, head_size, block_size]
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ key_cache: Optional[torch.Tensor],
+ value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
- ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
- # NOTE: The query, key, and value tensors must be sliced from a qkv
- # tensor of shape [num_tokens, 3 * num_heads * head_size].
+ ) -> torch.Tensor:
+ """PagedAttention forward pass.
+ NOTE: The query, key, and value tensors must be sliced from a qkv
+ tensor of shape [num_tokens, 3 * num_heads * head_size].
+ Args:
+ query: shape = [num_tokens, num_heads * head_size]
+ key: shape = [num_tokens, num_heads * head_size]
+ value: shape = [num_tokens, num_heads * head_size]
+ key_cache: shape = [num_blocks, num_heads, head_size/x,
+ block_size, x]
+ value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+ input_metadata: metadata for paged attention.
+ cache_event: event to wait for the cache operations to finish.
+ Returns:
+ shape = [num_tokens, num_heads * head_size]
+ """
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
@@ -136,7 +171,7 @@ class PagedAttention(nn.Module):
# and value vectors will not be cached.
num_valid_tokens = input_metadata.num_valid_tokens
if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None):
# The stride is 3 because the key and value are sliced from qkv.
cache_ops.reshape_and_cache(
key[:num_valid_tokens],
@@ -149,15 +184,12 @@ class PagedAttention(nn.Module):
if input_metadata.num_generation_tokens > 0:
assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when "
- "generating tokens."
- )
+ "generating tokens.")
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(
output[num_prompt_tokens:num_valid_tokens],
- query[num_prompt_tokens:num_valid_tokens],
- key_cache,
- value_cache,
- input_metadata)
+ query[num_prompt_tokens:num_valid_tokens], key_cache,
+ value_cache, input_metadata)
# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
@@ -179,9 +211,9 @@ class PagedAttentionWithRoPE(PagedAttention):
super().__init__(num_heads, head_size, scale)
# Create the cos and sin cache.
- inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
+ inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
- freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+ freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
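For reference, the cache built above is the standard rotary-embedding table: inv_freq[i] = 1 / base**(2*i / rotary_dim), the outer product with every position t gives the rotation angles, and the cos and sin of those angles are concatenated. A standalone shape check with illustrative sizes (not values taken from this diff):

    import torch

    base, rotary_dim, max_position = 10000, 128, 2048
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
    cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
    print(cache.shape)  # torch.Size([2048, 128]) == [max_position, rotary_dim]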
@@ -195,15 +227,32 @@ class PagedAttentionWithRoPE(PagedAttention):
def forward(
self,
- positions: torch.Tensor,  # [num_tokens]
- query: torch.Tensor,  # [num_tokens, num_heads * head_size]
- key: torch.Tensor,  # [num_tokens, num_heads * head_size]
- value: torch.Tensor,  # [num_tokens, num_heads * head_size]
- key_cache: torch.Tensor,  # [num_blocks, num_heads, head_size/x, block_size, x]
- value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
- ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
+ ) -> torch.Tensor:
+ """ PagedAttention forward pass with rotary embedding.
+ Args:
+ positions: shape = [num_tokens]
+ query: shape = [num_tokens, num_heads * head_size]
+ key: shape = [num_tokens, num_heads * head_size]
+ value: shape = [num_tokens, num_heads * head_size]
+ key_cache: shape = [num_blocks, num_heads, head_size/x,
+ block_size, x]
+ value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+ input_metadata: metadata for paged attention.
+ cache_event: event to wait for the cache operations to finish.
+ Returns:
+ shape = [num_tokens, num_heads * head_size]
+ """
# Apply rotary embedding to the query and key before passing them
# to the attention op.
pos_encoding_ops.rotary_embedding_neox(
...
@@ -13,6 +13,7 @@ from vllm.sequence import SequenceOutputs
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
"""Samples the next tokens from the model's outputs.
@@ -50,19 +51,20 @@ class Sampler(nn.Module):
# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0]
- presence_penalties, frequency_penalties = _get_penalties(input_metadata)
+ presence_penalties, frequency_penalties = _get_penalties(
+ input_metadata)
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
- logits = _apply_penalties(
- logits, output_tokens, presence_penalties, frequency_penalties,
- self.vocab_size)
+ logits = _apply_penalties(logits, output_tokens, presence_penalties,
+ frequency_penalties, self.vocab_size)
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures):
- t = torch.tensor(
- temperatures, dtype=logits.dtype, device=logits.device)
+ t = torch.tensor(temperatures,
+ dtype=logits.dtype,
+ device=logits.device)
# Use in-place division to avoid creating a new tensor.
logits.div_(t.unsqueeze(dim=1))
@@ -75,7 +77,9 @@ class Sampler(nn.Module):
# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == probs.shape[0]
- if any(p < 1.0 - _SAMPLING_EPS for p in top_ps) or any(k != self.vocab_size for k in top_ks):
+ do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
+ do_top_k = any(k != self.vocab_size for k in top_ks)
+ if do_top_p or do_top_k:
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
# Sample the next tokens.
@@ -97,8 +101,7 @@ def _prune_hidden_states(
def _get_penalties(
- input_metadata: InputMetadata,
- ) -> Tuple[List[float], List[float]]:
+ input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
@@ -117,9 +120,7 @@ def _get_penalties(
return presence_penalties, frequency_penalties
- def _get_output_tokens(
- input_metadata: InputMetadata,
- ) -> List[List[int]]:
+ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, _ = seq_group
@@ -169,11 +170,13 @@ def _apply_penalties(
device=logits.device)
frequency_penalties = [frequency_penalties[i] for i in indices]
- frequency_penalties = torch.tensor(
- frequency_penalties, dtype=logits.dtype, device=logits.device)
+ frequency_penalties = torch.tensor(frequency_penalties,
+ dtype=logits.dtype,
+ device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
- presence_penalties = torch.tensor(
- presence_penalties, dtype=logits.dtype, device=logits.device)
+ presence_penalties = torch.tensor(presence_penalties,
+ dtype=logits.dtype,
+ device=logits.device)
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
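Following the OpenAI definition referenced in the comment above, both penalties are subtracted from the logits: the frequency penalty scales with how many times a token already appears in the output, and the presence penalty applies once that count is non-zero. A dense-tensor sketch of the rule (illustrative numbers; the function itself builds the counts from output_tokens):

    import torch

    logits = torch.zeros(1, 8)
    counts = torch.tensor([[0., 2., 0., 1., 0., 0., 0., 0.]])  # token 1 seen twice, token 3 once
    frequency_penalty, presence_penalty = 0.5, 1.0
    logits = logits - counts * frequency_penalty - (counts > 0).float() * presence_penalty
    print(logits)  # token 1 -> -2.0, token 3 -> -1.5, everything else unchanged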
@@ -183,9 +186,7 @@ def _apply_penalties(
return logits
- def _get_temperatures(
- input_metadata: InputMetadata,
- ) -> List[float]:
+ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
@@ -252,8 +253,9 @@ def _apply_top_p_top_k(
probs_sort[top_k_mask] = 0.0
# Re-sort the probabilities.
- probs = torch.gather(
- probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
+ probs = torch.gather(probs_sort,
+ dim=-1,
+ index=torch.argsort(probs_idx, dim=-1))
return probs
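The gather with argsort(probs_idx) above undoes the earlier descending sort: argsort of the sort indices is the inverse permutation, so the filtered probabilities land back in vocabulary order. A tiny demonstration (not from this diff):

    import torch

    probs = torch.tensor([[0.1, 0.6, 0.3]])
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)  # [[0.6, 0.3, 0.1]], [[1, 2, 0]]
    restored = torch.gather(probs_sort,
                            dim=-1,
                            index=torch.argsort(probs_idx, dim=-1))
    print(torch.equal(restored, probs))  # True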
@@ -296,8 +298,9 @@ def _sample_from_prompt(
# Random sampling.
# Sample `best_of` tokens for the prompt.
num_seqs = sampling_params.best_of
- next_token_ids = torch.multinomial(
- prob, num_samples=num_seqs, replacement=True)
+ next_token_ids = torch.multinomial(prob,
+ num_samples=num_seqs,
+ replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids
@@ -315,8 +318,9 @@ def _sample_from_generation_tokens(
if sampling_params.use_beam_search:
# Beam search.
# Add cumulative logprobs for the sequences in the group.
- seq_logprobs = torch.tensor(
- seq_logprobs, dtype=torch.float, device=logprobs.device)
+ seq_logprobs = torch.tensor(seq_logprobs,
+ dtype=torch.float,
+ device=logprobs.device)
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
vocab_size = logprobs.size(-1)
@@ -353,8 +357,9 @@ def _sample_from_generation_tokens(
else:
# Random sampling.
# Sample 1 token for each sequence in the group.
- next_token_ids = torch.multinomial(
- probs, num_samples=1, replacement=True)
+ next_token_ids = torch.multinomial(probs,
+ num_samples=1,
+ replacement=True)
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
parent_seq_ids = seq_ids
return parent_seq_ids, next_token_ids
@@ -381,15 +386,16 @@ def _sample(
# Sample the next tokens.
next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens.
- next_logprobs = _get_topk_logprobs(
- logprob, sampling_params.logprobs)
+ next_logprobs = _get_topk_logprobs(logprob,
+ sampling_params.logprobs)
# Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item()
- seq_outputs[seq_id] = SequenceOutputs(
- seq_id, seq_id, next_token_id, output_logprobs)
+ seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
+ next_token_id,
+ output_logprobs)
else:
# Generate the next tokens for generation tokens.
prob = probs[idx:idx + len(seq_ids)]
@@ -399,22 +405,24 @@ def _sample(
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_data[seq_id].cumulative_logprob
- for seq_id in seq_ids]
+ for seq_id in seq_ids
+ ]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs: Dict[int, Dict[int, float]] = {}
- for i, seq_id in enumerate(seq_ids):
+ for j, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs(
- logprob[i], sampling_params.logprobs)
+ logprob[j], sampling_params.logprobs)
# Build the output.
for seq_id, parent_seq_id, next_token_id in zip(
seq_ids, parent_seq_ids, next_token_ids):
- i = seq_ids.index(parent_seq_id)
+ j = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy()
- output_logprobs[next_token_id] = logprob[i, next_token_id].item()
+ output_logprobs[next_token_id] = logprob[j,
+ next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id,
parent_seq_id,
...
@@ -6,8 +6,9 @@ import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig
- from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM, GPTNeoXForCausalLM,
- LlamaForCausalLM, OPTForCausalLM)
+ from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
+ GPTNeoXForCausalLM, LlamaForCausalLM,
+ OPTForCausalLM)
from vllm.model_executor.weight_utils import initialize_dummy_weights
# TODO(woosuk): Lazy-load the model classes.
@@ -28,8 +29,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
return _MODEL_REGISTRY[arch]
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
- f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
- )
+ f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
def get_model(model_config: ModelConfig) -> nn.Module:
@@ -46,8 +46,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
initialize_dummy_weights(model)
else:
# Load the weights from the cached or downloaded files.
- model.load_weights(
- model_config.model, model_config.download_dir,
- model_config.use_np_weights)
+ model.load_weights(model_config.model, model_config.download_dir,
+ model_config.use_np_weights)
model = model.cuda()
return model.eval()
@@ -4,8 +4,6 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
__all__ = [
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",
...
# coding=utf-8
- # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+ # Adapted from
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
@@ -47,19 +48,25 @@ class GPT2Attention(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
- tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+ tensor_model_parallel_world_size = (
+ get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
- self.scale = self.head_dim ** -0.5
- self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
- bias=True, gather_output=False,
+ self.scale = self.head_dim**-0.5
+ self.c_attn = ColumnParallelLinear(self.hidden_size,
+ 3 * self.hidden_size,
+ bias=True,
+ gather_output=False,
perform_initialization=False)
- self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
- bias=True, input_is_parallel=True,
+ self.c_proj = RowParallelLinear(self.hidden_size,
+ self.hidden_size,
+ bias=True,
+ input_is_parallel=True,
perform_initialization=False)
- self.attn = PagedAttention(self.num_heads, self.head_dim,
+ self.attn = PagedAttention(self.num_heads,
+ self.head_dim,
scale=self.scale)
def forward(
@@ -72,8 +79,8 @@ class GPT2Attention(nn.Module):
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
- attn_output = self.attn(
- q, k, v, key_cache, value_cache, input_metadata, cache_event)
+ attn_output = self.attn(q, k, v, key_cache, value_cache,
+ input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output)
return attn_output
@@ -87,11 +94,15 @@ class GPT2MLP(nn.Module):
):
super().__init__()
hidden_size = config.hidden_size
- self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
- bias=True, gather_output=False,
+ self.c_fc = ColumnParallelLinear(hidden_size,
+ intermediate_size,
+ bias=True,
+ gather_output=False,
perform_initialization=False)
- self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
- bias=True, input_is_parallel=True,
+ self.c_proj = RowParallelLinear(intermediate_size,
+ hidden_size,
+ bias=True,
+ input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)
@@ -107,7 +118,8 @@ class GPT2Block(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
hidden_size = config.hidden_size
- inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+ inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+ hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config)
@@ -145,9 +157,9 @@ class GPT2Model(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
self.config = config
- assert config.add_cross_attention == False
- assert config.scale_attn_by_inverse_layer_idx == False
- assert config.reorder_and_upcast_attn == False
+ assert not config.add_cross_attention
+ assert not config.scale_attn_by_inverse_layer_idx
+ assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
@@ -180,8 +192,8 @@ class GPT2Model(nn.Module):
else:
cache_event = cache_events[i]
layer = self.h[i]
- hidden_states = layer(
- hidden_states, kv_caches[i], input_metadata, cache_event)
+ hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
+ cache_event)
hidden_states = self.ln_f(hidden_states)
return hidden_states
@@ -206,24 +218,26 @@ class GPT2LMHeadModel(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
- hidden_states = self.transformer(
- input_ids, positions, kv_caches, input_metadata, cache_events)
- next_tokens = self.sampler(
- self.lm_head_weight, hidden_states, input_metadata)
+ hidden_states = self.transformer(input_ids, positions, kv_caches,
+ input_metadata, cache_events)
+ next_tokens = self.sampler(self.lm_head_weight, hidden_states,
+ input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
- def load_weights(self, model_name_or_path: str,
+ def load_weights(self,
+ model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
- tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+ tensor_model_parallel_world_size = (
+ get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
@@ -248,16 +262,20 @@ class GPT2LMHeadModel(nn.Module):
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
- padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
+ padded_vocab_size = (param.shape[0] *
+ tensor_model_parallel_world_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size
- extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
+ extra_rows = torch.empty(num_extra_rows,
+ loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
- # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
- # When tensor parallelism is used, we shard the weights along the head dimension.
+ # GPT-2's fused QKV has the shape of
+ # [3 * num_heads * head_size, hidden_size].
+ # When tensor parallelism is used, we shard the weights along
+ # the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
@@ -266,11 +284,13 @@ class GPT2LMHeadModel(nn.Module):
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
- loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
+ loaded_weight = loaded_weight.view(3, total_num_heads,
+ head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
- loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
+ loaded_weight = loaded_weight.view(3, total_num_heads,
+ head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
...
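To make the c_attn sharding above concrete, here is a toy-sized sketch of slicing one rank's heads out of a fused QKV weight (all numbers invented; the real code takes them from the model config and the tensor-parallel rank):

    import torch

    total_num_heads, head_size, hidden_size = 4, 8, 32
    world_size, rank = 2, 1
    num_heads = total_num_heads // world_size                        # 2 heads per rank
    head_start, head_end = rank * num_heads, (rank + 1) * num_heads  # heads 2..3

    qkv_weight = torch.randn(3 * total_num_heads * head_size, hidden_size)
    w = qkv_weight.view(3, total_num_heads, head_size, hidden_size)
    w = w[:, head_start:head_end, :, :].reshape(-1, hidden_size)
    print(w.shape)  # torch.Size([48, 32]) == [3 * num_heads * head_size, hidden_size]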
# coding=utf-8
- # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+ # Adapted from
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
@@ -49,19 +50,25 @@ class GPTBigCodeAttention(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
- tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+ tensor_model_parallel_world_size = (
+ get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
- self.scale = self.head_dim ** -0.5
- self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
- bias=True, gather_output=False,
+ self.scale = self.head_dim**-0.5
+ self.c_attn = ColumnParallelLinear(self.hidden_size,
+ 3 * self.hidden_size,
+ bias=True,
+ gather_output=False,
perform_initialization=False)
- self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
- bias=True, input_is_parallel=True,
+ self.c_proj = RowParallelLinear(self.hidden_size,
+ self.hidden_size,
+ bias=True,
+ input_is_parallel=True,
perform_initialization=False)
- self.attn = PagedAttention(self.num_heads, self.head_dim,
+ self.attn = PagedAttention(self.num_heads,
+ self.head_dim,
scale=self.scale)
def forward(
@@ -74,8 +81,8 @@ class GPTBigCodeAttention(nn.Module):
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
- attn_output = self.attn(
- q, k, v, key_cache, value_cache, input_metadata, cache_event)
+ attn_output = self.attn(q, k, v, key_cache, value_cache,
+ input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output)
return attn_output
@@ -89,11 +96,15 @@ class GPTBigMLP(nn.Module):
):
super().__init__()
hidden_size = config.hidden_size
- self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
- bias=True, gather_output=False,
+ self.c_fc = ColumnParallelLinear(hidden_size,
+ intermediate_size,
+ bias=True,
+ gather_output=False,
perform_initialization=False)
- self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
- bias=True, input_is_parallel=True,
+ self.c_proj = RowParallelLinear(intermediate_size,
+ hidden_size,
+ bias=True,
+ input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)
@@ -109,7 +120,8 @@ class GPTBigCodeBlock(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
hidden_size = config.hidden_size
- inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+ inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+ hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config)
@@ -147,7 +159,7 @@ class GPTBigCodeModel(nn.Module):
def __init__(self, config: GPTBigCodeConfig): def __init__(self, config: GPTBigCodeConfig):
super().__init__() super().__init__()
self.config = config self.config = config
assert config.add_cross_attention == False assert not config.add_cross_attention
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -181,8 +193,8 @@ class GPTBigCodeModel(nn.Module): ...@@ -181,8 +193,8 @@ class GPTBigCodeModel(nn.Module):
else: else:
cache_event = cache_events[i] cache_event = cache_events[i]
layer = self.h[i] layer = self.h[i]
hidden_states = layer( hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
hidden_states, kv_caches[i], input_metadata, cache_event) cache_event)
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
...@@ -207,24 +219,26 @@ class GPTBigCodeForCausalLM(nn.Module): ...@@ -207,24 +219,26 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer( hidden_states = self.transformer(input_ids, positions, kv_caches,
input_ids, positions, kv_caches, input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler( next_tokens = self.sampler(self.lm_head_weight, hidden_states,
self.lm_head_weight, hidden_states, input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"] _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"] _row_parallel_weights = ["c_proj.weight"]
def load_weights(self, model_name_or_path: str, def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_np_cache: bool = False): use_np_cache: bool = False):
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache): model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name: if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final # GPT-2 ties the weights of the embedding layer and the final
# linear layer. # linear layer.
...@@ -241,9 +255,11 @@ class GPTBigCodeForCausalLM(nn.Module): ...@@ -241,9 +255,11 @@ class GPTBigCodeForCausalLM(nn.Module):
if name == "transformer.wte.weight": if name == "transformer.wte.weight":
# Consider padding in the vocab size. # Consider padding in the vocab size.
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size padded_vocab_size = param.shape[
0] * tensor_model_parallel_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]) extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight) extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
...@@ -258,25 +274,31 @@ class GPTBigCodeForCausalLM(nn.Module): ...@@ -258,25 +274,31 @@ class GPTBigCodeForCausalLM(nn.Module):
qkv_array = qkv_array.numpy() qkv_array = qkv_array.numpy()
dims_q = n_head * head_dim dims_q = n_head * head_dim
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0) # pylint: disable=unbalanced-tuple-unpacking
# q is fine, but k & v have not replicated shape along the first axis q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
# as long as MQA is not natively supported, increase memory and replicated            axis=0)
# (head_dim, hidden_dim) to (n_heads * head_dim, hidden_dim) # q is fine, but k & v have not replicated shape along the first
# axis as long as MQA is not natively supported, increase memory
# and replicated (head_dim, hidden_dim) to
# (n_heads * head_dim, hidden_dim)
if k.ndim == 2 and v.ndim == 2: if k.ndim == 2 and v.ndim == 2:
replication = (n_head, 1) # weights replication = (n_head, 1) # weights
else: else:
replication = n_head # biases replication = n_head # biases
# replicate n_head times for q, v # replicate n_head times for q, v
k, v = np.tile(k, replication), np.tile(v, replication) k, v = np.tile(k, replication), np.tile(v, replication)
# concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) # concat q, k, v along the first axis
# (n_heads * head_dim, hidden_dim)
# to (3 * n_heads * head_dim, hidden_dim) # to (3 * n_heads * head_dim, hidden_dim)
qkv_array = np.concatenate((q, k, v), axis=0) qkv_array = np.concatenate((q, k, v), axis=0)
return torch.from_numpy(qkv_array) return torch.from_numpy(qkv_array)
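The _expand_mqa_mha helper above turns GPTBigCode's multi-query layout (many query heads, one shared key/value head) into an ordinary multi-head layout by tiling k and v. A compact sketch of the same transformation with made-up sizes:

import numpy as np
import torch

n_head, head_dim, hidden = 4, 8, 32     # illustrative sizes only
dims_q = n_head * head_dim

# MQA checkpoint layout: n_head query heads, a single shared k and v head.
qkv_array = np.random.randn(dims_q + 2 * head_dim, hidden).astype(np.float32)

q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
k, v = np.tile(k, (n_head, 1)), np.tile(v, (n_head, 1))   # replicate shared k/v
qkv_array = np.concatenate((q, k, v), axis=0)
assert torch.from_numpy(qkv_array).shape == (3 * n_head * head_dim, hidden)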
# For the fused QKV linear layer, manually shard the weights. # For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name: if "c_attn" in name:
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size]. # GPT-2's fused QKV has the shape of
# When tensor parallelism is used, we shard the weights along the head dimension. # [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads head_size = hidden_size // total_num_heads
...@@ -285,13 +307,19 @@ class GPTBigCodeForCausalLM(nn.Module): ...@@ -285,13 +307,19 @@ class GPTBigCodeForCausalLM(nn.Module):
head_end = (tensor_model_parallel_rank + 1) * num_heads head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"): if name.endswith(".weight"):
loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size) loaded_weight = _expand_mqa_mha(loaded_weight,
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size) n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :] loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size) loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"): elif name.endswith(".bias"):
loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size) loaded_weight = _expand_mqa_mha(loaded_weight,
loaded_weight = loaded_weight.view(3, total_num_heads, head_size) n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :] loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1) loaded_weight = loaded_weight.reshape(-1)
else: else:
......
# coding=utf-8 # coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. # Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
# #
...@@ -48,19 +49,23 @@ class GPTNeoXAttention(nn.Module): ...@@ -48,19 +49,23 @@ class GPTNeoXAttention(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads self.head_size = self.hidden_size // self.total_num_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0 assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.query_key_value = ColumnParallelLinear(config.hidden_size,
3 * config.hidden_size, self.query_key_value = ColumnParallelLinear(
gather_output=False, config.hidden_size,
perform_initialization=False) 3 * config.hidden_size,
self.dense = RowParallelLinear(config.hidden_size, config.hidden_size, gather_output=False,
perform_initialization=False)
self.dense = RowParallelLinear(config.hidden_size,
config.hidden_size,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
scaling = self.head_size ** -0.5 scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct) rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0 assert rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size, self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
...@@ -78,8 +83,8 @@ class GPTNeoXAttention(nn.Module): ...@@ -78,8 +83,8 @@ class GPTNeoXAttention(nn.Module):
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn( attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event) input_metadata, cache_event)
output, _ = self.dense(attn_output) output, _ = self.dense(attn_output)
return output return output
...@@ -92,7 +97,8 @@ class GPTNeoXMLP(nn.Module): ...@@ -92,7 +97,8 @@ class GPTNeoXMLP(nn.Module):
config.intermediate_size, config.intermediate_size,
gather_output=False, gather_output=False,
perform_initialization=False) perform_initialization=False)
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size, self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.act = get_act_fn(config.hidden_act) self.act = get_act_fn(config.hidden_act)
...@@ -109,8 +115,10 @@ class GPTNeoXLayer(nn.Module): ...@@ -109,8 +115,10 @@ class GPTNeoXLayer(nn.Module):
def __init__(self, config: GPTNeoXConfig): def __init__(self, config: GPTNeoXConfig):
super().__init__() super().__init__()
self.use_parallel_residual = config.use_parallel_residual self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.input_layernorm = nn.LayerNorm(config.hidden_size,
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config) self.attention = GPTNeoXAttention(config)
self.mlp = GPTNeoXMLP(config) self.mlp = GPTNeoXMLP(config)
...@@ -154,10 +162,13 @@ class GPTNeoXModel(nn.Module): ...@@ -154,10 +162,13 @@ class GPTNeoXModel(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size, self.embed_in = VocabParallelEmbedding(config.vocab_size,
config.hidden_size,
perform_initialization=False) perform_initialization=False)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList(
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward( def forward(
self, self,
...@@ -191,8 +202,10 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -191,8 +202,10 @@ class GPTNeoXForCausalLM(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.gpt_neox = GPTNeoXModel(config) self.gpt_neox = GPTNeoXModel(config)
self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size, self.embed_out = ColumnParallelLinear(config.hidden_size,
bias=False, gather_output=False, config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False) perform_initialization=False)
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
...@@ -204,24 +217,28 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -204,24 +217,28 @@ class GPTNeoXForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> Dict[int, SequenceOutputs]:
hidden_states = self.gpt_neox( hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_ids, positions, kv_caches, input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler( next_tokens = self.sampler(self.embed_out.weight, hidden_states,
self.embed_out.weight, hidden_states, input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"] _column_parallel_weights = [
"embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
"dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"] _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self, model_name_or_path: str, def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_np_cache: bool = False): use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache): model_name_or_path, cache_dir, use_np_cache):
if ("attention.bias" in name or "attention.masked_bias" in name if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name): or "rotary_emb.inv_freq" in name):
continue continue
param = state_dict[name] param = state_dict[name]
if "query_key_value" in name: if "query_key_value" in name:
...@@ -230,17 +247,19 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -230,17 +247,19 @@ class GPTNeoXForCausalLM(nn.Module):
# required shape is [3 * num_heads * head_size, hidden_size]. # required shape is [3 * num_heads * head_size, hidden_size].
# Thus, we need weight conversion. # Thus, we need weight conversion.
shard_size = param.shape[0] shard_size = param.shape[0]
loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank loaded_weight = loaded_weight[
:shard_size * (tensor_model_parallel_rank + 1)] shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
num_heads = self.config.num_attention_heads num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size hidden_size = self.config.hidden_size
head_size = hidden_size // num_heads head_size = hidden_size // num_heads
if 'query_key_value.weight' in name: if "query_key_value.weight" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size) loaded_weight = loaded_weight.view(-1, 3, head_size,
hidden_size)
loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, hidden_size) loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif 'query_key_value.bias' in name: elif "query_key_value.bias" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size) loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1) loaded_weight = loaded_weight.reshape(-1)
......
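A minimal sketch, with illustrative sizes, of the weight conversion in the query_key_value branch above: the GPT-NeoX checkpoint interleaves q/k/v per head ([num_heads * 3 * head_size, hidden_size]), while the fused layer expects all q heads first, then all k, then all v ([3 * num_heads * head_size, hidden_size]).

import torch

num_heads, head_size = 4, 8             # illustrative only
hidden_size = num_heads * head_size

# Checkpoint layout: q/k/v interleaved per head.
loaded_weight = torch.randn(num_heads * 3 * head_size, hidden_size)

converted = (loaded_weight.view(-1, 3, head_size, hidden_size)  # [heads, 3, ...]
             .transpose(0, 1)                                   # [3, heads, ...]
             .reshape(-1, hidden_size))
assert converted.shape == (3 * num_heads * head_size, hidden_size)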
# coding=utf-8 # coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# #
...@@ -30,7 +31,6 @@ import torch ...@@ -30,7 +31,6 @@ import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -56,15 +56,19 @@ class LlamaMLP(nn.Module): ...@@ -56,15 +56,19 @@ class LlamaMLP(nn.Module):
hidden_act: str, hidden_act: str,
): ):
super().__init__() super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size, self.gate_up_proj = ColumnParallelLinear(hidden_size,
bias=False, gather_output=False, 2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False) perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size, self.down_proj = RowParallelLinear(intermediate_size,
bias=False, input_is_parallel=True, hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
if hidden_act != 'silu': if hidden_act != "silu":
raise ValueError(f'Unsupported activation: {hidden_act}. ' raise ValueError(f"Unsupported activation: {hidden_act}. "
'Only silu is supported for now.') "Only silu is supported for now.")
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
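The diff only shows the LlamaMLP constructor, so the following is a sketch of what the fused activation is assumed to compute (the usual SwiGLU pattern): the 2 * intermediate_size output of gate_up_proj is split in half, SiLU is applied to the gate half, and the result multiplies the up half. silu_and_mul below is a local stand-in, not vLLM's SiluAndMul API.

import torch
import torch.nn.functional as F

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Assumed fused activation: silu(gate) * up over a concatenated input."""
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

x = torch.randn(4, 2 * 11008)       # illustrative [num_tokens, 2 * intermediate]
assert silu_and_mul(x).shape == (4, 11008)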
def forward(self, x): def forward(self, x):
...@@ -83,12 +87,14 @@ class LlamaAttention(nn.Module): ...@@ -83,12 +87,14 @@ class LlamaAttention(nn.Module):
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
self.total_num_heads = num_heads self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0 assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim ** -0.5 self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = ColumnParallelLinear(
hidden_size, hidden_size,
...@@ -104,8 +110,10 @@ class LlamaAttention(nn.Module): ...@@ -104,8 +110,10 @@ class LlamaAttention(nn.Module):
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False, perform_initialization=False,
) )
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim, self.attn = PagedAttentionWithRoPE(self.num_heads,
self.scaling, rotary_dim=self.head_dim) self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
def forward( def forward(
self, self,
...@@ -118,8 +126,8 @@ class LlamaAttention(nn.Module): ...@@ -118,8 +126,8 @@ class LlamaAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn( attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) input_metadata, cache_event)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -138,8 +146,10 @@ class LlamaDecoderLayer(nn.Module): ...@@ -138,8 +146,10 @@ class LlamaDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
) )
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size,
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward( def forward(
self, self,
...@@ -177,9 +187,13 @@ class LlamaModel(nn.Module): ...@@ -177,9 +187,13 @@ class LlamaModel(nn.Module):
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size, self.embed_tokens = VocabParallelEmbedding(
perform_initialization=False) config.vocab_size,
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward( def forward(
...@@ -209,6 +223,7 @@ class LlamaModel(nn.Module): ...@@ -209,6 +223,7 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module): class LlamaForCausalLM(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -228,39 +243,42 @@ class LlamaForCausalLM(nn.Module): ...@@ -228,39 +243,42 @@ class LlamaForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> Dict[int, SequenceOutputs]:
hidden_states = self.model( hidden_states = self.model(input_ids, positions, kv_caches,
input_ids, positions, kv_caches, input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler( next_tokens = self.sampler(self.lm_head.weight, hidden_states,
self.lm_head.weight, hidden_states, input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight", _column_parallel_weights = [
"qkv_proj.weight", "gate_proj.weight", "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
"up_proj.weight"] "gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"] _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, model_name_or_path: str, def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_np_cache: bool = False): use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache): model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
is_attention_weight = False is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]): for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name: if att_weight_name not in name:
continue continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")] param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3 shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank shard_size * tensor_model_parallel_rank:shard_size *
:shard_size * (tensor_model_parallel_rank + 1)] (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id param_slice = param.data[shard_size * stride_id:shard_size *
:shard_size * (stride_id + 1)] (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight) param_slice.copy_(loaded_weight)
is_attention_weight = True is_attention_weight = True
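The loop above stitches the separate q_proj / k_proj / v_proj checkpoint tensors into the model's fused qkv_proj parameter: one stripe per projection, and within each stripe only this rank's slice of the checkpoint. A standalone sketch with illustrative sizes and a single rank:

import torch

hidden = 32                          # illustrative hidden size
rank, world_size = 0, 1              # single rank keeps the sketch simple
shard_size = hidden // world_size    # rows of each projection held per rank

qkv_proj = torch.zeros(3 * shard_size, hidden)   # fused model parameter
checkpoint = {name: torch.randn(hidden, hidden)
              for name in ("q_proj", "k_proj", "v_proj")}

for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
    loaded = checkpoint[att_weight_name]
    # This rank's slice of the checkpoint tensor...
    loaded = loaded[shard_size * rank:shard_size * (rank + 1)]
    # ...copied into the matching stripe of the fused parameter.
    qkv_proj[shard_size * stride_id:shard_size * (stride_id + 1)].copy_(loaded)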
...@@ -275,10 +293,10 @@ class LlamaForCausalLM(nn.Module): ...@@ -275,10 +293,10 @@ class LlamaForCausalLM(nn.Module):
param = state_dict[name.replace(weight_name, "gate_up_proj")] param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2 shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank shard_size * tensor_model_parallel_rank:shard_size *
:shard_size * (tensor_model_parallel_rank + 1)] (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id param_slice = param.data[shard_size * stride_id:shard_size *
:shard_size * (stride_id + 1)] (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight) param_slice.copy_(loaded_weight)
is_gate_up_weight = True is_gate_up_weight = True
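The gate_proj / up_proj branch above follows the same stripe-copy pattern with two stripes instead of three (shard_size = param.shape[0] // 2), presumably iterating over the two projection names and copying each rank slice into the fused gate_up_proj parameter.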
......
# coding=utf-8 # coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -43,8 +45,9 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] ...@@ -43,8 +45,9 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding): class OPTLearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int): def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 # OPT is set up so that if padding_idx is specified then offset the
# and adjust num_embeddings appropriately. Other models don't have this hack # embedding ids by 2 and adjust num_embeddings appropriately. Other
# models don't have this hack
self.offset = 2 self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim) super().__init__(num_embeddings + self.offset, embedding_dim)
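A tiny illustration of the offset hack documented above. Only the constructor appears in this diff, so the forward shown here is an assumption about how the offset is applied (shift the position ids by 2 before the lookup); sizes are illustrative.

import torch
import torch.nn as nn

class OffsetPositionalEmbedding(nn.Embedding):
    """Sketch of the OPT convention: embedding ids are offset by 2."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Attribute set before super().__init__(), mirroring the constructor above.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        return super().forward(positions + self.offset)

emb = OffsetPositionalEmbedding(num_embeddings=2048, embedding_dim=16)
assert emb(torch.arange(4)).shape == (4, 16)   # rows 2..5 of a 2050-row table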
...@@ -62,20 +65,26 @@ class OPTAttention(nn.Module): ...@@ -62,20 +65,26 @@ class OPTAttention(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
total_num_heads = num_heads total_num_heads = num_heads
assert num_heads % tensor_model_parallel_world_size == 0 assert num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = embed_dim // total_num_heads self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim ** -0.5 self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias, self.qkv_proj = ColumnParallelLinear(embed_dim,
3 * embed_dim,
bias=bias,
gather_output=False, gather_output=False,
perform_initialization=False) perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, self.out_proj = RowParallelLinear(embed_dim,
embed_dim,
bias=bias,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim, self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling) scale=self.scaling)
def forward( def forward(
...@@ -88,8 +97,8 @@ class OPTAttention(nn.Module): ...@@ -88,8 +97,8 @@ class OPTAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
attn_output = self.attn( attn_output = self.attn(q, k, v, key_cache, value_cache,
q, k, v, key_cache, value_cache, input_metadata, cache_event) input_metadata, cache_event)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
...@@ -109,17 +118,21 @@ class OPTDecoderLayer(nn.Module): ...@@ -109,17 +118,21 @@ class OPTDecoderLayer(nn.Module):
self.activation_fn = get_act_fn(config.activation_function) self.activation_fn = get_act_fn(config.activation_function)
self.self_attn_layer_norm = nn.LayerNorm( self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) self.embed_dim,
self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim, elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim,
config.ffn_dim,
bias=config.enable_bias, bias=config.enable_bias,
gather_output=False, gather_output=False,
perform_initialization=False) perform_initialization=False)
self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim, self.fc2 = RowParallelLinear(config.ffn_dim,
self.embed_dim,
bias=config.enable_bias, bias=config.enable_bias,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.final_layer_norm = nn.LayerNorm( self.final_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
def forward( def forward(
self, self,
...@@ -133,11 +146,10 @@ class OPTDecoderLayer(nn.Module): ...@@ -133,11 +146,10 @@ class OPTDecoderLayer(nn.Module):
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before: if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn( hidden_states = self.self_attn(hidden_states=hidden_states,
hidden_states=hidden_states, kv_cache=kv_cache,
kv_cache=kv_cache, input_metadata=input_metadata,
input_metadata=input_metadata, cache_event=cache_event)
cache_event=cache_event)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention # 350m applies layer norm AFTER attention
if not self.do_layer_norm_before: if not self.do_layer_norm_before:
...@@ -167,35 +179,42 @@ class OPTDecoder(nn.Module): ...@@ -167,35 +179,42 @@ class OPTDecoder(nn.Module):
self.max_target_positions = config.max_position_embeddings self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(
config.word_embed_proj_dim, config.vocab_size,
perform_initialization=False) config.word_embed_proj_dim,
perform_initialization=False)
# Positional embeddings are replicated (not sharded). # Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding( self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size) config.max_position_embeddings, config.hidden_size)
# Project out & in will be replicated if they exist. # Project out & in will be replicated if they exist.
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) self.project_out = nn.Linear(config.hidden_size,
config.word_embed_proj_dim,
bias=False)
else: else:
self.project_out = None self.project_out = None
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) self.project_in = nn.Linear(config.word_embed_proj_dim,
config.hidden_size,
bias=False)
else: else:
self.project_in = None self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility # Note that the only purpose of `config._remove_final_layer_norm` is to
# with checkpoints that have been fine-tuned before transformers v4.20.1 # keep backward compatibility with checkpoints that have been fine-tuned
# before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164 # see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm: if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm( self.final_layer_norm = nn.LayerNorm(
config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine config.hidden_size,
) elementwise_affine=config.layer_norm_elementwise_affine)
else: else:
self.final_layer_norm = None self.final_layer_norm = None
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList(
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
def forward( def forward(
self, self,
...@@ -217,8 +236,8 @@ class OPTDecoder(nn.Module): ...@@ -217,8 +236,8 @@ class OPTDecoder(nn.Module):
else: else:
cache_event = cache_events[i] cache_event = cache_events[i]
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
hidden_states, kv_caches[i], input_metadata, cache_event) cache_event)
if self.final_layer_norm is not None: if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
...@@ -241,8 +260,8 @@ class OPTModel(nn.Module): ...@@ -241,8 +260,8 @@ class OPTModel(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor: ) -> torch.Tensor:
return self.decoder( return self.decoder(input_ids, positions, kv_caches, input_metadata,
input_ids, positions, kv_caches, input_metadata, cache_events) cache_events)
class OPTForCausalLM(nn.Module): class OPTForCausalLM(nn.Module):
...@@ -264,23 +283,26 @@ class OPTForCausalLM(nn.Module): ...@@ -264,23 +283,26 @@ class OPTForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> Dict[int, SequenceOutputs]:
hidden_states = self.model( hidden_states = self.model(input_ids, positions, kv_caches,
input_ids, positions, kv_caches, input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler( next_tokens = self.sampler(self.lm_head_weight, hidden_states,
self.lm_head_weight, hidden_states, input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"] _column_parallel_weights = [
"embed_tokens.weight", "fc1.weight", "fc1.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"] _row_parallel_weights = ["out_proj.weight", "fc2.weight"]
def load_weights(self, model_name_or_path: str, def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_np_cache: bool = False): use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache): model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name: if "lm_head.weight" in name:
continue continue
...@@ -288,16 +310,17 @@ class OPTForCausalLM(nn.Module): ...@@ -288,16 +310,17 @@ class OPTForCausalLM(nn.Module):
name = "model." + name name = "model." + name
is_attention_weight = False is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]): for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name: if att_weight_name not in name:
continue continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")] param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3 shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank shard_size * tensor_model_parallel_rank:shard_size *
:shard_size * (tensor_model_parallel_rank + 1)] (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id param_slice = param.data[shard_size * stride_id:shard_size *
:shard_size * (stride_id + 1)] (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight) param_slice.copy_(loaded_weight)
is_attention_weight = True is_attention_weight = True
......
...@@ -44,9 +44,9 @@ def hf_model_weights_iterator( ...@@ -44,9 +44,9 @@ def hf_model_weights_iterator(
if use_np_cache: if use_np_cache:
# Convert the model weights from torch tensors to numpy arrays for # Convert the model weights from torch tensors to numpy arrays for
# faster loading. # faster loading.
np_folder = os.path.join(hf_folder, 'np') np_folder = os.path.join(hf_folder, "np")
os.makedirs(np_folder, exist_ok=True) os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, 'weight_names.json') weight_names_file = os.path.join(np_folder, "weight_names.json")
with lock: with lock:
if not os.path.exists(weight_names_file): if not os.path.exists(weight_names_file):
weight_names = [] weight_names = []
...@@ -57,10 +57,10 @@ def hf_model_weights_iterator( ...@@ -57,10 +57,10 @@ def hf_model_weights_iterator(
with open(param_path, "wb") as f: with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy()) np.save(f, param.cpu().detach().numpy())
weight_names.append(name) weight_names.append(name)
with open(weight_names_file, 'w') as f: with open(weight_names_file, "w") as f:
json.dump(weight_names, f) json.dump(weight_names, f)
with open(weight_names_file, 'r') as f: with open(weight_names_file, "r") as f:
weight_names = json.load(f) weight_names = json.load(f)
for name in weight_names: for name in weight_names:
...@@ -86,17 +86,16 @@ def load_tensor_parallel_weights( ...@@ -86,17 +86,16 @@ def load_tensor_parallel_weights(
for p in column_parallel_weight_names: for p in column_parallel_weight_names:
if p in param_name: if p in param_name:
shard_size = param.shape[0] shard_size = param.shape[0]
loaded_weight = loaded_weight[ start_idx = tensor_model_parallel_rank * shard_size
shard_size * tensor_model_parallel_rank end_idx = (tensor_model_parallel_rank + 1) * shard_size
:shard_size * (tensor_model_parallel_rank + 1)] loaded_weight = loaded_weight[start_idx:end_idx]
break break
for p in row_parallel_weight_names: for p in row_parallel_weight_names:
if p in param_name: if p in param_name:
shard_size = param.shape[1] shard_size = param.shape[1]
loaded_weight = loaded_weight[ start_idx = tensor_model_parallel_rank * shard_size
:, end_idx = (tensor_model_parallel_rank + 1) * shard_size
shard_size * tensor_model_parallel_rank loaded_weight = loaded_weight[:, start_idx:end_idx]
:shard_size * (tensor_model_parallel_rank + 1)]
break break
assert param.shape == loaded_weight.shape, ( assert param.shape == loaded_weight.shape, (
f"{param_name} shape mismatch between model and checkpoint: " f"{param_name} shape mismatch between model and checkpoint: "
......
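Finally, a compact sketch of the column/row tensor-parallel slicing that the refactored load_tensor_parallel_weights helper performs above. For simplicity the shard size is derived from the full checkpoint tensor here; in the diff it is read from the per-rank parameter shape.

import torch

rank, world_size = 1, 4                  # illustrative tensor-parallel layout
loaded_weight = torch.randn(64, 64)      # full checkpoint weight

# Column-parallel layers shard dim 0 (the output features).
shard_size = loaded_weight.shape[0] // world_size
start_idx, end_idx = rank * shard_size, (rank + 1) * shard_size
col_shard = loaded_weight[start_idx:end_idx]

# Row-parallel layers shard dim 1 (the input features).
shard_size = loaded_weight.shape[1] // world_size
start_idx, end_idx = rank * shard_size, (rank + 1) * shard_size
row_shard = loaded_weight[:, start_idx:end_idx]

assert col_shard.shape == (16, 64) and row_shard.shape == (64, 16)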