Unverified Commit d6fa1be3 authored by Zhuohan Li, committed by GitHub

[Quality] Add code formatter and linter (#326)

parent 0ffded81
......@@ -67,8 +67,7 @@ class LLMEngine:
f"download_dir={model_config.download_dir!r}, "
f"use_np_weights={model_config.use_np_weights}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"seed={model_config.seed})"
)
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
......@@ -78,8 +77,8 @@ class LLMEngine:
self.log_stats = log_stats
self._verify_args()
self.tokenizer = get_tokenizer(model_config.tokenizer,
model_config.tokenizer_mode)
self.tokenizer = get_tokenizer(
model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
self.seq_counter = Counter()
# Create the parallel GPU workers.
......@@ -129,8 +128,8 @@ class LLMEngine:
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
# FIXME(woosuk): Change to debug log.
logger.info(f'# GPU blocks: {num_gpu_blocks}, '
f'# CPU blocks: {num_cpu_blocks}')
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
......@@ -152,7 +151,9 @@ class LLMEngine:
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config)
# Create the LLM engine.
engine = cls(*engine_configs, distributed_init_method, devices,
engine = cls(*engine_configs,
distributed_init_method,
devices,
log_stats=not engine_args.disable_log_stats)
return engine
......@@ -226,8 +227,10 @@ class LLMEngine:
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
(seq_group_metadata_list, scheduler_outputs,
ignored_seq_groups) = self.scheduler.schedule()
if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
and (not ignored_seq_groups)):
# Nothing to do.
return []
......@@ -281,8 +284,8 @@ class LLMEngine:
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)]
self.scheduler.free_seq(seq,
SequenceStatus.FINISHED_STOPPED)
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_STOPPED)
stopped = True
break
if stopped:
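For readers skimming this hunk, a minimal standalone sketch of the stop-string truncation being reformatted here. The matching condition is not shown in the diff, so the endswith check and the helper name below are assumptions for illustration only:

from typing import List, Tuple

def truncate_at_stop_str(output_text: str, stop_strs: List[str]) -> Tuple[str, bool]:
    # Mirrors the logic above: if the text now ends with a stop string, drop the
    # stop string from the output and mark the sequence as finished.
    for stop_str in stop_strs:
        if output_text.endswith(stop_str):
            return output_text[:-len(stop_str)], True
    return output_text, False

print(truncate_at_stop_str("The answer is 42.###", ["###"]))
# ('The answer is 42.', True)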
......@@ -290,7 +293,7 @@ class LLMEngine:
# Check if the sequence has reached max_seq_len.
if (seq.get_len() >=
self.scheduler.scheduler_config.max_seq_len):
self.scheduler.scheduler_config.max_seq_len):
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
continue
......@@ -302,15 +305,15 @@ class LLMEngine:
# Check if the sequence has generated the EOS token.
if not sampling_params.ignore_eos:
if seq.get_last_token_id() == self.tokenizer.eos_token_id:
self.scheduler.free_seq(seq,
SequenceStatus.FINISHED_STOPPED)
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_STOPPED)
continue
def _run_workers(
self,
method: str,
get_all_outputs: bool = False,
*args,
get_all_outputs: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
......
......@@ -8,7 +8,8 @@ except ImportError:
from vllm.config import ParallelConfig
DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), device id
# rank, node resource (node IP), device id
DeviceID = Tuple[int, Optional[str], int]
def initialize_cluster(
......@@ -53,15 +54,15 @@ def initialize_cluster(
valid_node_resources = []
num_devices_per_node = None
for node in ray.nodes():
if (not node['Alive']) or node['Resources']['GPU'] <= 0:
if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
continue
if num_devices_per_node is None:
num_devices_per_node = node['Resources']['GPU']
num_devices_per_node = node["Resources"]["GPU"]
else:
assert num_devices_per_node == node['Resources']['GPU'], (
assert num_devices_per_node == node["Resources"]["GPU"], (
"The number of GPUs per node is not uniform.")
for key in node['Resources']:
if key.startswith('node:'):
for key in node["Resources"]:
if key.startswith("node:"):
valid_node_resources.append(key)
# Verify the parallel config.
......
......@@ -11,8 +11,8 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
app = FastAPI()
......@@ -37,8 +37,7 @@ async def generate(request: Request) -> Response:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [
prompt + output.text
for output in request_output.outputs
prompt + output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")
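Since this endpoint emits NUL-delimited JSON chunks, a hedged client-side sketch (not part of this commit) of how they could be consumed; the request payload fields and endpoint path are assumptions based on this server's /generate handler:

import json
import requests

def stream_generate(prompt: str, host: str = "localhost", port: int = 8000):
    response = requests.post(f"http://{host}:{port}/generate",
                             json={"prompt": prompt, "stream": True},
                             stream=True)
    buffer = b""
    for chunk in response.iter_content(chunk_size=None):
        buffer += chunk
        # Each complete message ends with a NUL byte, matching the server's
        # yield (json.dumps(ret) + "\0").encode("utf-8") above.
        while b"\0" in buffer:
            message, buffer = buffer.split(b"\0", 1)
            yield json.loads(message.decode("utf-8"))["text"]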
......@@ -63,10 +62,7 @@ async def generate(request: Request) -> Response:
assert final_output is not None
prompt = final_output.prompt
text_outputs = [
prompt + output.text
for output in final_output.outputs
]
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return Response(content=json.dumps(ret))
......@@ -81,5 +77,8 @@ if __name__ == "__main__":
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
......@@ -63,8 +63,7 @@ class LLM:
self.request_counter = Counter()
def get_tokenizer(
self,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer
def set_tokenizer(
......
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
import argparse
from http import HTTPStatus
......@@ -29,7 +30,7 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__)
served_model = None
......@@ -38,14 +39,13 @@ app = fastapi.FastAPI()
def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
return JSONResponse(
ErrorResponse(message=message, type="invalid_request_error").dict(),
status_code=status_code.value
)
return JSONResponse(ErrorResponse(message=message,
type="invalid_request_error").dict(),
status_code=status_code.value)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
async def validation_exception_handler(request, exc): # pylint: disable=unused-argument
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
......@@ -126,8 +126,11 @@ async def check_length(request, prompt, engine):
@app.get("/v1/models")
async def show_available_models():
"""Show available models. Right now we only have one model."""
model_cards = [ModelCard(id=served_model, root=served_model,
permission=[ModelPermission()])]
model_cards = [
ModelCard(id=served_model,
root=served_model,
permission=[ModelPermission()])
]
return ModelList(data=model_cards)
......@@ -144,12 +147,14 @@ def create_logprobs(token_ids: List[int],
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
logprobs.top_logprobs.append(
{tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()})
logprobs.top_logprobs.append({
tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()
})
return logprobs
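A small worked example of the text_offset bookkeeping reformatted above, assuming an initial_text_offset of 0; the helper below is illustrative, not part of the API:

def build_text_offsets(tokens, initial_text_offset=0):
    # Each entry is the character offset at which the corresponding token
    # starts in the detokenized text, exactly as in the loop above.
    text_offset = []
    last_token_len = 0
    for token in tokens:
        if not text_offset:
            text_offset.append(initial_text_offset)
        else:
            text_offset.append(text_offset[-1] + last_token_len)
        last_token_len = len(token)
    return text_offset

print(build_text_offsets(["Hello", ",", " world"]))  # [0, 5, 6]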
......@@ -348,7 +353,7 @@ async def create_completion(raw_request: Request):
if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")
"suffix is not currently supported")
if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine.
......@@ -387,22 +392,23 @@ async def create_completion(raw_request: Request):
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params,
request_id)
result_generator = engine.generate(prompt, sampling_params, request_id)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
stream = (request.stream and
(request.best_of is None or request.n == request.best_of) and
not request.use_beam_search)
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json(index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None) -> str:
def create_stream_response_json(
index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None,
) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
......@@ -443,7 +449,8 @@ async def create_completion(raw_request: Request):
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = LogProbs() if request.logprobs is not None else None
logprobs = (LogProbs()
if request.logprobs is not None else None)
response_json = create_stream_response_json(
index=i,
text="",
......@@ -487,8 +494,8 @@ async def create_completion(raw_request: Request):
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids)
for output in final_res.outputs)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
......@@ -506,9 +513,11 @@ async def create_completion(raw_request: Request):
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
......@@ -517,26 +526,34 @@ async def create_completion(raw_request: Request):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server."
)
parser.add_argument("--host", type=str, default="localhost", help="host name")
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host",
type=str,
default="localhost",
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument(
"--allow-credentials", action="store_true", help="allow credentials"
)
parser.add_argument(
"--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
)
parser.add_argument(
"--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
)
parser.add_argument(
"--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
)
parser.add_argument("--served-model-name", type=str, default=None,
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
"--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
......@@ -556,7 +573,11 @@ if __name__ == "__main__":
engine = AsyncLLMEngine.from_engine_args(engine_args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(engine_args.tokenizer, engine_args.tokenizer_mode)
tokenizer = get_tokenizer(engine_args.tokenizer,
tokenizer_mode=engine_args.tokenizer_mode)
uvicorn.run(app, host=args.host, port=args.port, log_level="info",
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
......@@ -98,7 +99,8 @@ class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
class CompletionResponseChoice(BaseModel):
......
# Adapted from https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
# Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM."""
import logging
import sys
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"
......
......@@ -2,7 +2,6 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.utils import set_random_seed
__all__ = [
"InputMetadata",
"get_model",
......
......@@ -8,11 +8,22 @@ from vllm.sequence import SequenceData
class InputMetadata:
"""Metadata for input sequences. Used for PagedAttention.
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
slot_mapping: The KV-cache address to write each token's new KV to.
context_lens: The length of attention context for each generation token.
max_context_len: The maximum context length.
block_tables: The block tables. (Seq id -> list of physical block)
"""
def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]], # List of (seq_ids, sampling_params).
seq_data: Dict[int, SequenceData], # Seq_id -> SequenceData.
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
......
......@@ -6,9 +6,10 @@ from vllm import activation_ops
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
"gelu_new": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
"gelu_fast": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
# NOTE: The following GELU functions may introduce small rounding errors.
"gelu_new": nn.GELU(approximate="tanh"),
"gelu_fast": nn.GELU(approximate="tanh"),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
}
......@@ -25,15 +26,13 @@ class SiluAndMul(nn.Module):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
"""
def __init__(self):
super().__init__()
Shapes:
x: (num_tokens, 2 * d)
return: (num_tokens, d)
"""
def forward(
self,
x: torch.Tensor, # (num_tokens, 2 * d)
) -> torch.Tensor: # (num_tokens, d)
def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
d = x.shape[1] // 2
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
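As a sanity check, the SwiGLU activation described in the docstring above can be expressed directly in PyTorch; a rough reference, not the custom activation_ops CUDA path:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, 2 * d)  ->  (num_tokens, d)
    d = x.shape[1] // 2
    return F.silu(x[:, :d]) * x[:, d:]

x = torch.randn(3, 8)
print(silu_and_mul_reference(x).shape)  # torch.Size([3, 4])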
......
......@@ -14,6 +14,7 @@ _SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
class PagedAttention(nn.Module):
# pylint: disable=line-too-long
"""GPT-style multi-head PagedAttention.
This class takes flattened 1D query, key, and value tensors as input. The
......@@ -54,12 +55,20 @@ class PagedAttention(nn.Module):
def multi_query_kv_attention(
self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: xops.AttentionBias,
) -> torch.Tensor:
"""Normal attention for the prompt tokens.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_heads, head_size]
value: shape = [num_prompt_tokens, num_heads, head_size]
"""
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
......@@ -76,12 +85,22 @@ class PagedAttention(nn.Module):
def single_query_cached_kv_attention(
self,
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
) -> None:
"""PagedAttention for the generation tokens.
Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
......@@ -97,16 +116,32 @@ class PagedAttention(nn.Module):
def forward(
self,
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size, block_size]
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# NOTE: The query, key, and value tensors must be sliced from a qkv
# tensor of shape [num_tokens, 3 * num_heads * head_size].
) -> torch.Tensor:
"""PagedAttention forward pass.
NOTE: The query, key, and value tensors must be sliced from a qkv
tensor of shape [num_tokens, 3 * num_heads * head_size].
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_heads * head_size]
value: shape = [num_tokens, num_heads * head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
......@@ -136,7 +171,7 @@ class PagedAttention(nn.Module):
# and value vectors will not be cached.
num_valid_tokens = input_metadata.num_valid_tokens
if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None):
and value_cache is not None):
# The stride is 3 because the key and value are sliced from qkv.
cache_ops.reshape_and_cache(
key[:num_valid_tokens],
......@@ -149,15 +184,12 @@ class PagedAttention(nn.Module):
if input_metadata.num_generation_tokens > 0:
assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when "
"generating tokens."
)
"generating tokens.")
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(
output[num_prompt_tokens:num_valid_tokens],
query[num_prompt_tokens:num_valid_tokens],
key_cache,
value_cache,
input_metadata)
query[num_prompt_tokens:num_valid_tokens], key_cache,
value_cache, input_metadata)
# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
......@@ -179,9 +211,9 @@ class PagedAttentionWithRoPE(PagedAttention):
super().__init__(num_heads, head_size, scale)
# Create the cos and sin cache.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
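For context on the cos/sin cache built above, a hedged sketch of how a rotary embedding is typically applied to one query vector. The actual rotation happens inside the pos_encoding_ops kernel, whose exact pairing convention is not shown in this diff, so the half-split rotation below is an assumption:

import torch

def build_rope_cache(rotary_dim: int, max_position: int, base: int = 10000):
    # Mirrors the cache construction above.
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
    return freqs.cos(), freqs.sin()

def apply_rope_half_split(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                          position: int) -> torch.Tensor:
    # Rotate the two halves of the rotary dimensions; cos/sin each hold
    # rotary_dim // 2 entries per position.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    c, s = cos[position], sin[position]
    return torch.cat((x1 * c - x2 * s, x1 * s + x2 * c), dim=-1)

cos, sin = build_rope_cache(rotary_dim=8, max_position=16)
print(apply_rope_half_split(torch.randn(8), cos, sin, position=3).shape)  # torch.Size([8])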
......@@ -195,15 +227,32 @@ class PagedAttentionWithRoPE(PagedAttention):
def forward(
self,
positions: torch.Tensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
) -> torch.Tensor:
""" PagedAttention forward pass with rotary embedding.
Args:
positions: shape = [num_tokens]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_heads * head_size]
value: shape = [num_tokens, num_heads * head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Apply rotary embedding to the query and key before passing them
# to the attention op.
pos_encoding_ops.rotary_embedding_neox(
......
......@@ -13,6 +13,7 @@ from vllm.sequence import SequenceOutputs
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
"""Samples the next tokens from the model's outputs.
......@@ -50,19 +51,20 @@ class Sampler(nn.Module):
# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0]
presence_penalties, frequency_penalties = _get_penalties(input_metadata)
presence_penalties, frequency_penalties = _get_penalties(
input_metadata)
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties(
logits, output_tokens, presence_penalties, frequency_penalties,
self.vocab_size)
logits = _apply_penalties(logits, output_tokens, presence_penalties,
frequency_penalties, self.vocab_size)
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures):
t = torch.tensor(
temperatures, dtype=logits.dtype, device=logits.device)
t = torch.tensor(temperatures,
dtype=logits.dtype,
device=logits.device)
# Use in-place division to avoid creating a new tensor.
logits.div_(t.unsqueeze(dim=1))
......@@ -75,7 +77,9 @@ class Sampler(nn.Module):
# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == probs.shape[0]
if any(p < 1.0 - _SAMPLING_EPS for p in top_ps) or any(k != self.vocab_size for k in top_ks):
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
do_top_k = any(k != self.vocab_size for k in top_ks)
if do_top_p or do_top_k:
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
# Sample the next tokens.
......@@ -97,8 +101,7 @@ def _prune_hidden_states(
def _get_penalties(
input_metadata: InputMetadata,
) -> Tuple[List[float], List[float]]:
input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
......@@ -117,9 +120,7 @@ def _get_penalties(
return presence_penalties, frequency_penalties
def _get_output_tokens(
input_metadata: InputMetadata,
) -> List[List[int]]:
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, _ = seq_group
......@@ -169,11 +170,13 @@ def _apply_penalties(
device=logits.device)
frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor(
frequency_penalties, dtype=logits.dtype, device=logits.device)
frequency_penalties = torch.tensor(frequency_penalties,
dtype=logits.dtype,
device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor(
presence_penalties, dtype=logits.dtype, device=logits.device)
presence_penalties = torch.tensor(presence_penalties,
dtype=logits.dtype,
device=logits.device)
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
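Since the comment above points at the OpenAI parameter definitions, a compact single-row sketch of the penalty rule this helper applies; the real code vectorizes it over the whole batch, and the counts below are illustrative:

import torch

def apply_penalties_reference(logits: torch.Tensor, output_tokens, presence_penalty: float,
                              frequency_penalty: float) -> torch.Tensor:
    # OpenAI-style definition: subtract frequency_penalty * count plus
    # presence_penalty for every token that has already been generated.
    counts = torch.zeros_like(logits)
    for tok in output_tokens:
        counts[tok] += 1
    return logits - frequency_penalty * counts - presence_penalty * (counts > 0).float()

logits = torch.zeros(5)
print(apply_penalties_reference(logits, [2, 2, 4], presence_penalty=0.5, frequency_penalty=0.1))
# tensor([ 0.0000,  0.0000, -0.7000,  0.0000, -0.6000])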
......@@ -183,9 +186,7 @@ def _apply_penalties(
return logits
def _get_temperatures(
input_metadata: InputMetadata,
) -> List[float]:
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
......@@ -252,8 +253,9 @@ def _apply_top_p_top_k(
probs_sort[top_k_mask] = 0.0
# Re-sort the probabilities.
probs = torch.gather(
probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
probs = torch.gather(probs_sort,
dim=-1,
index=torch.argsort(probs_idx, dim=-1))
return probs
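A minimal single-row sketch of the top-p/top-k truncation whose formatting is touched here; the real _apply_top_p_top_k operates on batched tensors with per-row thresholds, and the helper name below is illustrative:

import torch

def top_p_top_k_single(probs: torch.Tensor, top_p: float, top_k: int) -> torch.Tensor:
    # Sort descending, zero out tokens outside the nucleus / top-k,
    # then restore the original order (mirrors the gather/argsort above).
    probs_sort, probs_idx = probs.sort(descending=True)
    probs_sum = probs_sort.cumsum(dim=-1)
    mask = (probs_sum - probs_sort) > top_p
    mask |= torch.arange(probs.numel()) >= top_k
    probs_sort[mask] = 0.0
    return torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))

p = torch.tensor([0.1, 0.5, 0.05, 0.3, 0.05])
print(top_p_top_k_single(p, top_p=0.85, top_k=3))
# approximately tensor([0.1000, 0.5000, 0.0000, 0.3000, 0.0000])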
......@@ -296,8 +298,9 @@ def _sample_from_prompt(
# Random sampling.
# Sample `best_of` tokens for the prompt.
num_seqs = sampling_params.best_of
next_token_ids = torch.multinomial(
prob, num_samples=num_seqs, replacement=True)
next_token_ids = torch.multinomial(prob,
num_samples=num_seqs,
replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids
......@@ -315,8 +318,9 @@ def _sample_from_generation_tokens(
if sampling_params.use_beam_search:
# Beam search.
# Add cumulative logprobs for the sequences in the group.
seq_logprobs = torch.tensor(
seq_logprobs, dtype=torch.float, device=logprobs.device)
seq_logprobs = torch.tensor(seq_logprobs,
dtype=torch.float,
device=logprobs.device)
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
vocab_size = logprobs.size(-1)
......@@ -353,8 +357,9 @@ def _sample_from_generation_tokens(
else:
# Random sampling.
# Sample 1 token for each sequence in the group.
next_token_ids = torch.multinomial(
probs, num_samples=1, replacement=True)
next_token_ids = torch.multinomial(probs,
num_samples=1,
replacement=True)
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
parent_seq_ids = seq_ids
return parent_seq_ids, next_token_ids
......@@ -381,15 +386,16 @@ def _sample(
# Sample the next tokens.
next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(
logprob, sampling_params.logprobs)
next_logprobs = _get_topk_logprobs(logprob,
sampling_params.logprobs)
# Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id, seq_id, next_token_id, output_logprobs)
seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
next_token_id,
output_logprobs)
else:
# Generate the next tokens for generation tokens.
prob = probs[idx:idx + len(seq_ids)]
......@@ -399,22 +405,24 @@ def _sample(
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids]
for seq_id in seq_ids
]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs: Dict[int, Dict[int, float]] = {}
for i, seq_id in enumerate(seq_ids):
for j, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs(
logprob[i], sampling_params.logprobs)
logprob[j], sampling_params.logprobs)
# Build the output.
for seq_id, parent_seq_id, next_token_id in zip(
seq_ids, parent_seq_ids, next_token_ids):
i = seq_ids.index(parent_seq_id)
seq_ids, parent_seq_ids, next_token_ids):
j = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy()
output_logprobs[next_token_id] = logprob[i, next_token_id].item()
output_logprobs[next_token_id] = logprob[j,
next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id,
parent_seq_id,
......
......@@ -6,8 +6,9 @@ import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig
from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM, GPTNeoXForCausalLM,
LlamaForCausalLM, OPTForCausalLM)
from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
GPTNeoXForCausalLM, LlamaForCausalLM,
OPTForCausalLM)
from vllm.model_executor.weight_utils import initialize_dummy_weights
# TODO(woosuk): Lazy-load the model classes.
......@@ -28,8 +29,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
return _MODEL_REGISTRY[arch]
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
)
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
def get_model(model_config: ModelConfig) -> nn.Module:
......@@ -46,8 +46,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
initialize_dummy_weights(model)
else:
# Load the weights from the cached or downloaded files.
model.load_weights(
model_config.model, model_config.download_dir,
model_config.use_np_weights)
model.load_weights(model_config.model, model_config.download_dir,
model_config.use_np_weights)
model = model.cuda()
return model.eval()
......@@ -4,8 +4,6 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
__all__ = [
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",
......
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
......@@ -47,19 +48,25 @@ class GPT2Attention(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim ** -0.5
self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
bias=True, gather_output=False,
self.c_attn = ColumnParallelLinear(self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
bias=True, input_is_parallel=True,
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim,
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale)
def forward(
......@@ -72,8 +79,8 @@ class GPT2Attention(nn.Module):
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output)
return attn_output
......@@ -87,11 +94,15 @@ class GPT2MLP(nn.Module):
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
bias=True, gather_output=False,
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=True, input_is_parallel=True,
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)
......@@ -107,7 +118,8 @@ class GPT2Block(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config)
......@@ -145,9 +157,9 @@ class GPT2Model(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
self.config = config
assert config.add_cross_attention == False
assert config.scale_attn_by_inverse_layer_idx == False
assert config.reorder_and_upcast_attn == False
assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx
assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
......@@ -180,8 +192,8 @@ class GPT2Model(nn.Module):
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
hidden_states, kv_caches[i], input_metadata, cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = self.ln_f(hidden_states)
return hidden_states
......@@ -206,24 +218,26 @@ class GPT2LMHeadModel(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head_weight, hidden_states, input_metadata)
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
......@@ -248,16 +262,20 @@ class GPT2LMHeadModel(nn.Module):
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
padded_vocab_size = (param.shape[0] *
tensor_model_parallel_world_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along the head dimension.
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
......@@ -266,11 +284,13 @@ class GPT2LMHeadModel(nn.Module):
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
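To make the head-dimension sharding in this hunk concrete, a shape-only walkthrough with toy dimensions (these numbers are illustrative, not GPT-2's real config):

import torch

# Hypothetical config: 8 heads of size 4, hidden_size 32, tensor parallel size 2, rank 1.
total_num_heads, head_size, hidden_size, tp_size, tp_rank = 8, 4, 32, 2, 1
num_heads = total_num_heads // tp_size          # 4 heads per rank
head_start = tp_rank * num_heads                # 4
head_end = (tp_rank + 1) * num_heads            # 8

# Fused QKV weight: [3 * num_heads * head_size, hidden_size].
loaded_weight = torch.randn(3 * total_num_heads * head_size, hidden_size)
shard = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
shard = shard[:, head_start:head_end, :, :].reshape(-1, hidden_size)
print(shard.shape)  # torch.Size([48, 32]) == [3 * num_heads * head_size, hidden_size]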
......
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
......@@ -49,19 +50,25 @@ class GPTBigCodeAttention(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim ** -0.5
self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
bias=True, gather_output=False,
self.c_attn = ColumnParallelLinear(self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
bias=True, input_is_parallel=True,
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim,
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale)
def forward(
......@@ -74,8 +81,8 @@ class GPTBigCodeAttention(nn.Module):
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output)
return attn_output
......@@ -89,11 +96,15 @@ class GPTBigMLP(nn.Module):
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
bias=True, gather_output=False,
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=True, input_is_parallel=True,
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)
......@@ -109,7 +120,8 @@ class GPTBigCodeBlock(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config)
......@@ -147,7 +159,7 @@ class GPTBigCodeModel(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
self.config = config
assert config.add_cross_attention == False
assert not config.add_cross_attention
self.embed_dim = config.hidden_size
......@@ -181,8 +193,8 @@ class GPTBigCodeModel(nn.Module):
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
hidden_states, kv_caches[i], input_metadata, cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = self.ln_f(hidden_states)
return hidden_states
......@@ -207,24 +219,26 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head_weight, hidden_states, input_metadata)
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
......@@ -241,9 +255,11 @@ class GPTBigCodeForCausalLM(nn.Module):
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
padded_vocab_size = param.shape[
0] * tensor_model_parallel_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
......@@ -258,25 +274,31 @@ class GPTBigCodeForCausalLM(nn.Module):
qkv_array = qkv_array.numpy()
dims_q = n_head * head_dim
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
# q is fine, but k & v have not replicated shape along the first axis
# as long as MQA is not nativly supported, increase memory and replicated
# (head_dim, hidden_dim) to (n_heads * head_dim, hidden_dim)
# pylint: disable=unbalanced-tuple-unpacking
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
axis=0)
# q is fine, but k & v have not been replicated along the first
# axis; as long as MQA is not natively supported, increase memory
# and replicate (head_dim, hidden_dim) to
# (n_heads * head_dim, hidden_dim)
if k.ndim == 2 and v.ndim == 2:
replication = (n_head, 1) # weights
else:
replication = n_head # biases
# replicate n_head times for q, v
k, v = np.tile(k, replication), np.tile(v, replication)
# concat q, k, v along the first axis (n_heads * head_dim, hidden_dim)
# concat q, k, v along the first axis
# (n_heads * head_dim, hidden_dim)
# to (3 * n_heads * head_dim, hidden_dim)
qkv_array = np.concatenate((q, k, v), axis=0)
return torch.from_numpy(qkv_array)
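A shape-only illustration of the multi-query-attention expansion described in the comments above, in NumPy with toy dimensions:

import numpy as np

n_head, head_dim, hidden_dim = 4, 8, 32
# MQA checkpoint layout assumed by _expand_mqa_mha: full-size q, single-head k and v.
qkv_array = np.random.randn(n_head * head_dim + 2 * head_dim, hidden_dim)

dims_q = n_head * head_dim
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
k, v = np.tile(k, (n_head, 1)), np.tile(v, (n_head, 1))   # replicate k/v per head
expanded = np.concatenate((q, k, v), axis=0)
print(expanded.shape)  # (96, 32) == (3 * n_head * head_dim, hidden_dim)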
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along the head dimension.
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
......@@ -285,13 +307,19 @@ class GPTBigCodeForCausalLM(nn.Module):
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
loaded_weight = _expand_mqa_mha(loaded_weight,
n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
loaded_weight = _expand_mqa_mha(loaded_weight,
n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
......
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
......@@ -48,19 +49,23 @@ class GPTNeoXAttention(nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.query_key_value = ColumnParallelLinear(config.hidden_size,
3 * config.hidden_size,
gather_output=False,
perform_initialization=False)
self.dense = RowParallelLinear(config.hidden_size, config.hidden_size,
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.query_key_value = ColumnParallelLinear(
config.hidden_size,
3 * config.hidden_size,
gather_output=False,
perform_initialization=False)
self.dense = RowParallelLinear(config.hidden_size,
config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
scaling = self.head_size ** -0.5
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
......@@ -78,8 +83,8 @@ class GPTNeoXAttention(nn.Module):
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.dense(attn_output)
return output
......@@ -92,7 +97,8 @@ class GPTNeoXMLP(nn.Module):
config.intermediate_size,
gather_output=False,
perform_initialization=False)
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.hidden_act)
......@@ -109,8 +115,10 @@ class GPTNeoXLayer(nn.Module):
def __init__(self, config: GPTNeoXConfig):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config)
self.mlp = GPTNeoXMLP(config)
......@@ -154,10 +162,13 @@ class GPTNeoXModel(nn.Module):
super().__init__()
self.config = config
self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
self.embed_in = VocabParallelEmbedding(config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layers = nn.ModuleList(
[GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
......@@ -191,8 +202,10 @@ class GPTNeoXForCausalLM(nn.Module):
super().__init__()
self.config = config
self.gpt_neox = GPTNeoXModel(config)
self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
bias=False, gather_output=False,
self.embed_out = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
......@@ -204,24 +217,28 @@ class GPTNeoXForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.gpt_neox(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.embed_out.weight, hidden_states, input_metadata)
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
_column_parallel_weights = [
"embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
"dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
or "rotary_emb.inv_freq" in name):
continue
param = state_dict[name]
if "query_key_value" in name:
......@@ -230,17 +247,19 @@ class GPTNeoXForCausalLM(nn.Module):
# required shape is [3 * num_heads * head_size, hidden_size].
# Thus, we need weight conversion.
shard_size = param.shape[0]
loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // num_heads
if 'query_key_value.weight' in name:
loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
if "query_key_value.weight" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size,
hidden_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif 'query_key_value.bias' in name:
elif "query_key_value.bias" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
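The weight-layout conversion commented above, from [num_heads * 3 * head_size, hidden_size] to the required [3 * num_heads * head_size, hidden_size], can be checked with a toy tensor:

import torch

num_heads, head_size, hidden_size = 4, 8, 32
# GPT-NeoX checkpoint layout: heads outermost, then (q, k, v) interleaved per head.
w = torch.randn(num_heads * 3 * head_size, hidden_size)

converted = w.view(num_heads, 3, head_size, hidden_size)   # the view(-1, 3, ...) above
converted = converted.transpose(0, 1)                      # (3, num_heads, head_size, hidden_size)
converted = converted.reshape(-1, hidden_size)             # [3 * num_heads * head_size, hidden_size]
print(converted.shape)  # torch.Size([96, 32])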
......
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
......@@ -30,7 +31,6 @@ import torch
from torch import nn
from transformers import LlamaConfig
from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -56,15 +56,19 @@ class LlamaMLP(nn.Module):
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
bias=False, gather_output=False,
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=False, input_is_parallel=True,
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
if hidden_act != 'silu':
raise ValueError(f'Unsupported activation: {hidden_act}. '
'Only silu is supported for now.')
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
......@@ -83,12 +87,14 @@ class LlamaAttention(nn.Module):
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim ** -0.5
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(
hidden_size,
......@@ -104,8 +110,10 @@ class LlamaAttention(nn.Module):
input_is_parallel=True,
perform_initialization=False,
)
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim,
self.scaling, rotary_dim=self.head_dim)
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
def forward(
self,
......@@ -118,8 +126,8 @@ class LlamaAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
......@@ -138,8 +146,10 @@ class LlamaDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
......@@ -177,9 +187,13 @@ class LlamaModel(nn.Module):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
......@@ -209,6 +223,7 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
......@@ -228,39 +243,42 @@ class LlamaForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head.weight, hidden_states, input_metadata)
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
"qkv_proj.weight", "gate_proj.weight",
"up_proj.weight"]
_column_parallel_weights = [
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
"gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name:
continue
is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
......@@ -275,10 +293,10 @@ class LlamaForCausalLM(nn.Module):
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
......
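Aside (not part of the diff): the reformatted load_weights above copies each q/k/v checkpoint weight into one fused qkv_proj parameter, keeping only the slice owned by the current tensor-parallel rank. A minimal standalone sketch of that slicing, with dummy shapes and rank values that are purely illustrative:

    import torch

    # Assume 2 tensor-parallel ranks; all names and sizes here are
    # hypothetical, chosen only to mirror the indexing in the diff.
    world_size, rank = 2, 0
    hidden_size = 8
    shard_rows = hidden_size // world_size  # rows per rank, per projection
    qkv_param = torch.empty(3 * shard_rows, hidden_size)

    for stride_id, full_weight in enumerate(
            [torch.randn(hidden_size, hidden_size) for _ in range(3)]):  # q, k, v
        # Take the rows owned by this rank from the full checkpoint tensor...
        shard = full_weight[shard_rows * rank:shard_rows * (rank + 1)]
        # ...and copy them into the matching stripe of the fused parameter.
        qkv_param[shard_rows * stride_id:shard_rows * (stride_id + 1)].copy_(shard)

    assert qkv_param.shape == (3 * shard_rows, hidden_size)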
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -43,8 +45,9 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
# OPT is set up so that if padding_idx is specified then offset the
# embedding ids by 2 and adjust num_embeddings appropriately. Other
# models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
......@@ -62,20 +65,26 @@ class OPTAttention(nn.Module):
) -> None:
super().__init__()
self.embed_dim = embed_dim
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
total_num_heads = num_heads
assert num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim ** -0.5
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
self.qkv_proj = ColumnParallelLinear(embed_dim,
3 * embed_dim,
bias=bias,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
self.out_proj = RowParallelLinear(embed_dim,
embed_dim,
bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim,
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling)
def forward(
......@@ -88,8 +97,8 @@ class OPTAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
output, _ = self.out_proj(attn_output)
return output
......@@ -109,17 +118,21 @@ class OPTDecoderLayer(nn.Module):
self.activation_fn = get_act_fn(config.activation_function)
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim,
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
gather_output=False,
perform_initialization=False)
self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim,
self.fc2 = RowParallelLinear(config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
input_is_parallel=True,
perform_initialization=False)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
def forward(
self,
......@@ -133,11 +146,10 @@ class OPTDecoderLayer(nn.Module):
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event)
hidden_states = self.self_attn(hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
......@@ -167,35 +179,42 @@ class OPTDecoder(nn.Module):
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.word_embed_proj_dim,
perform_initialization=False)
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.word_embed_proj_dim,
perform_initialization=False)
# Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size)
# Project out & in will be replicated if they exist.
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
self.project_out = nn.Linear(config.hidden_size,
config.word_embed_proj_dim,
bias=False)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
self.project_in = nn.Linear(config.word_embed_proj_dim,
config.hidden_size,
bias=False)
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# Note that the only purpose of `config._remove_final_layer_norm` is to
# keep backward compatibility with checkpoints that have been fine-tuned
# before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm(
config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
)
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList(
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
self,
......@@ -217,8 +236,8 @@ class OPTDecoder(nn.Module):
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
hidden_states, kv_caches[i], input_metadata, cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
......@@ -241,8 +260,8 @@ class OPTModel(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
return self.decoder(
input_ids, positions, kv_caches, input_metadata, cache_events)
return self.decoder(input_ids, positions, kv_caches, input_metadata,
cache_events)
class OPTForCausalLM(nn.Module):
......@@ -264,23 +283,26 @@ class OPTForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head_weight, hidden_states, input_metadata)
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
_column_parallel_weights = [
"embed_tokens.weight", "fc1.weight", "fc1.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
continue
......@@ -288,16 +310,17 @@ class OPTForCausalLM(nn.Module):
name = "model." + name
is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
......
......@@ -44,9 +44,9 @@ def hf_model_weights_iterator(
if use_np_cache:
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, 'np')
np_folder = os.path.join(hf_folder, "np")
os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, 'weight_names.json')
weight_names_file = os.path.join(np_folder, "weight_names.json")
with lock:
if not os.path.exists(weight_names_file):
weight_names = []
......@@ -57,10 +57,10 @@ def hf_model_weights_iterator(
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
weight_names.append(name)
with open(weight_names_file, 'w') as f:
with open(weight_names_file, "w") as f:
json.dump(weight_names, f)
with open(weight_names_file, 'r') as f:
with open(weight_names_file, "r") as f:
weight_names = json.load(f)
for name in weight_names:
......@@ -86,17 +86,16 @@ def load_tensor_parallel_weights(
for p in column_parallel_weight_names:
if p in param_name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[start_idx:end_idx]
break
for p in row_parallel_weight_names:
if p in param_name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[:, start_idx:end_idx]
break
assert param.shape == loaded_weight.shape, (
f"{param_name} shape mismatch between model and checkpoint: "
......
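Aside (illustrative sketch, not part of the diff): the reworked indexing in load_tensor_parallel_weights takes a contiguous block of rows for column-parallel weights and a contiguous block of columns for row-parallel weights. The same start_idx/end_idx arithmetic on toy tensors, with hypothetical shapes and rank values:

    import torch

    # Toy setup: 4 tensor-parallel ranks; shapes are illustrative only.
    world_size, rank = 4, 1
    full_col_weight = torch.randn(16, 8)   # column-parallel: sharded along dim 0
    full_row_weight = torch.randn(8, 16)   # row-parallel: sharded along dim 1

    # Column-parallel: each rank keeps shard_size rows.
    shard_size = full_col_weight.shape[0] // world_size
    start_idx = rank * shard_size
    end_idx = (rank + 1) * shard_size
    col_shard = full_col_weight[start_idx:end_idx]

    # Row-parallel: each rank keeps shard_size columns.
    shard_size = full_row_weight.shape[1] // world_size
    start_idx = rank * shard_size
    end_idx = (rank + 1) * shard_size
    row_shard = full_row_weight[:, start_idx:end_idx]

    print(col_shard.shape, row_shard.shape)  # torch.Size([4, 8]) torch.Size([8, 4])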