Unverified Commit badf3fa0 authored by Lianmin Zheng, committed by GitHub

Expose dtype argument (#569)

parent 945aa9be
@@ -6,7 +6,7 @@ import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type, Any
+from typing import List, Optional, Type

 import numpy as np
 import torch
@@ -119,7 +119,7 @@ class InputMetadata:
             head_dim,
             1,
             pos_encoding_mode="NONE",
-            data_type="float16",
+            data_type=self.token_to_kv_pool.kv_data[0].dtype
         )

    def init_extend_args(self):
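Note on the hunk above: the flashinfer `data_type` now follows whatever dtype the KV pool was actually allocated with, rather than a hardcoded `"float16"`, so the attention kernels cannot disagree with the cache. A minimal illustration of the pattern (the tensor shape here is made up):

```python
import torch

# Stand-in for the pool's per-layer KV buffer; shape is arbitrary.
kv_data = torch.empty(4, 2, 8, 128, dtype=torch.bfloat16)

# Reading the dtype off the live tensor instead of hardcoding "float16"
# keeps downstream consumers in sync with however the pool was allocated.
data_type = kv_data.dtype
print(data_type)  # torch.bfloat16
```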
@@ -287,10 +287,11 @@ class ModelRunner:
             tokenizer=None,
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
-            dtype=torch.float16,
+            dtype=self.server_args.dtype,
             seed=42,
             skip_tokenizer_init=True,
         )
+        self.dtype = vllm_model_config.dtype

         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
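`self.dtype` is read back from the vLLM `ModelConfig` rather than from `server_args` directly, because vLLM resolves the string (including `"auto"`) into a concrete `torch.dtype`. A rough sketch of that resolution, inferred from the `--dtype` help text further down rather than copied from vLLM:

```python
import torch

_STR_TO_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float": torch.float32,
    "float32": torch.float32,
}

def resolve_dtype(dtype_arg: str, config_dtype: torch.dtype) -> torch.dtype:
    # "auto" keeps BF16 checkpoints in BF16 and downcasts FP32 ones to FP16;
    # explicit strings map directly to torch dtypes.
    if dtype_arg == "auto":
        return torch.float16 if config_dtype == torch.float32 else config_dtype
    return _STR_TO_DTYPE[dtype_arg]
```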
@@ -307,6 +308,7 @@
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
+            f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
@@ -316,7 +318,7 @@
         )
         head_dim = self.model_config.head_dim
         head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
+        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
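The old constant `* 2 * 2` was K/V (the first 2) times the 2-byte width of fp16; multiplying by the element size of the configured dtype instead makes the memory profiling correct for 4-byte caches too. A worked check of the formula:

```python
import torch

def kv_cell_size_bytes(head_num, head_dim, layer_num, dtype):
    # Per-token KV cache: K and V (the factor of 2) for every layer,
    # head_num * head_dim elements each, times bytes per element.
    return head_num * head_dim * layer_num * 2 * torch._utils._element_size(dtype)

# 8 kv heads x 128 head_dim x 32 layers:
print(kv_cell_size_bytes(8, 128, 32, torch.float16))  # 131072 (2 B/elem)
print(kv_cell_size_bytes(8, 128, 32, torch.float32))  # 262144 (4 B/elem)
```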
@@ -337,7 +339,7 @@
         )
         self.token_to_kv_pool = TokenToKVPool(
             self.max_total_num_tokens,
-            dtype=torch.float16,
+            dtype=self.dtype,
             head_num=self.model_config.get_num_kv_heads(self.tp_size),
             head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
......
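With the pool allocated in `self.dtype`, the buffer that `InputMetadata` inspects in the first hunk carries the right element type automatically. A toy sketch of such a pool (the field name `kv_data` comes from that hunk; shapes and layout are assumptions, not SGLang's exact implementation):

```python
import torch

class TokenToKVPoolSketch:
    """Toy stand-in: one K/V buffer per layer, allocated in the requested dtype."""

    def __init__(self, size, dtype, head_num, head_dim, layer_num):
        self.kv_data = [
            torch.empty(size, 2, head_num, head_dim, dtype=dtype)
            for _ in range(layer_num)
        ]

pool = TokenToKVPoolSketch(1024, torch.bfloat16, head_num=8, head_dim=128, layer_num=32)
print(pool.kv_data[0].dtype)  # torch.bfloat16
```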
@@ -120,7 +120,7 @@ class ModelTpServer:
             f"[gpu_id={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
-            f"context_len={self.model_config.context_len}, "
+            f"context_len={self.model_config.context_len}"
         )
         if self.tp_rank == 0:
             logger.info(
......
@@ -11,12 +11,13 @@ class ServerArgs:
     # Model and tokenizer
     model_path: str
     tokenizer_path: Optional[str] = None
-    load_format: str = "auto"
     tokenizer_mode: str = "auto"
-    chat_template: Optional[str] = None
+    load_format: str = "auto"
+    dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
+    chat_template: Optional[str] = None

     # Port
     host: str = "127.0.0.1"
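Since `ServerArgs` is a dataclass, the new field can also be set programmatically. A small usage sketch (the import path is assumed from SGLang's layout and may differ by version):

```python
from sglang.srt.server_args import ServerArgs  # assumed module path

# dtype defaults to "auto"; override it like any other field.
args = ServerArgs(model_path="meta-llama/Llama-2-7b-hf", dtype="bfloat16")
print(args.dtype)  # "bfloat16"
```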
@@ -107,6 +108,15 @@
             default=[],
             help="The additional ports specified for the server.",
         )
+        parser.add_argument(
+            "--tokenizer-mode",
+            type=str,
+            default=ServerArgs.tokenizer_mode,
+            choices=["auto", "slow"],
+            help="Tokenizer mode. 'auto' will use the fast "
+            "tokenizer if available, and 'slow' will "
+            "always use the slow tokenizer.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -124,20 +134,20 @@
             "which is mainly for profiling.",
         )
         parser.add_argument(
-            "--tokenizer-mode",
-            type=str,
-            default=ServerArgs.tokenizer_mode,
-            choices=["auto", "slow"],
-            help="Tokenizer mode. 'auto' will use the fast "
-            "tokenizer if available, and 'slow' will "
-            "always use the slow tokenizer.",
-        )
-        parser.add_argument(
-            "--chat-template",
+            "--dtype",
             type=str,
-            default=ServerArgs.chat_template,
-            help="The built-in chat template name or the path of the chat template file. This is only used for the OpenAI-compatible API server.",
-        )
+            default=ServerArgs.dtype,
+            choices=[
+                "auto", "half", "float16", "bfloat16", "float", "float32"
+            ],
+            help='Data type for model weights and activations.\n\n'
+            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
+            'BF16 precision for BF16 models.\n'
+            '* "half" for FP16. Recommended for AWQ quantization.\n'
+            '* "float16" is the same as "half".\n'
+            '* "bfloat16" for a balance between precision and range.\n'
+            '* "float" is shorthand for FP32 precision.\n'
+            '* "float32" for FP32 precision.')
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
@@ -155,6 +165,12 @@
             default=ServerArgs.quantization,
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--chat-template",
+            type=str,
+            default=ServerArgs.chat_template,
+            help="The built-in chat template name or the path of the chat template file. This is only used for the OpenAI-compatible API server.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
......