Unverified Commit 189ae231 authored by Woosuk Kwon, committed by GitHub

Use dtype from model config & Add Dolly V2 (#63)

parent e548c148
@@ -214,7 +214,11 @@ def add_server_arguments(parser: argparse.ArgumentParser):
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     # NOTE(woosuk): FlashAttention does not support float32.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
+    parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
+                        help=('data type for model weights and activations. '
+                              'The "default" option will use FP16 precision '
+                              'for FP32 and FP16 models, and BF16 precision '
+                              'for BF16 models.'))
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
...
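The flag now defaults to `'default'`, and `float32` is not even offered as a choice because FlashAttention does not support it. A minimal sketch (not part of the commit) of how the new choices parse; the parser below only mirrors the argument definition above:

```python
# Sketch only: mirrors the new --dtype argument from the diff above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dtype', type=str, default='default',
                    choices=['default', 'half', 'bfloat16'])

print(parser.parse_args([]).dtype)                       # 'default' -> follow the model config
print(parser.parse_args(['--dtype', 'bfloat16']).dtype)  # explicit override
# parser.parse_args(['--dtype', 'float32']) -> argparse error: invalid choice
```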
-from typing import Union, Optional
+from typing import Optional

 import torch
 import torch.nn as nn
 from transformers import AutoConfig
+from transformers import PretrainedConfig

 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
 from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
@@ -22,6 +23,7 @@ _MODELS = {
     'opt': OPTForCausalLM,
     'stablelm': GPTNeoXForCausalLM,
     'pythia': GPTNeoXForCausalLM,
+    'dolly-v2': GPTNeoXForCausalLM,
 }

 _MEMORY_ANALYZERS = {
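Dolly V2 reuses the existing GPT-NeoX implementation, so supporting it only needs new registry entries. A standalone sketch (not part of the commit) of how the substring-based lookup resolves a Dolly V2 checkpoint; `databricks/dolly-v2-3b` is an illustrative Hugging Face model id, and the string class names stand in for the real model classes:

```python
# Illustrative sketch: mimics the `model_class_name in model_name` lookup used below.
_MODELS = {
    'opt': 'OPTForCausalLM',
    'stablelm': 'GPTNeoXForCausalLM',
    'pythia': 'GPTNeoXForCausalLM',
    'dolly-v2': 'GPTNeoXForCausalLM',
}

model_name = 'databricks/dolly-v2-3b'  # assumed example checkpoint id
matches = [cls for key, cls in _MODELS.items() if key in model_name]
print(matches)  # ['GPTNeoXForCausalLM'] -> Dolly V2 is served by the GPT-NeoX code path
```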
@@ -30,19 +32,38 @@ _MEMORY_ANALYZERS = {
     'opt': OPTMemoryAnalyzer,
     'stablelm': GPTNeoXMemoryAnalyzer,
     'pythia': GPTNeoXMemoryAnalyzer,
+    'dolly-v2': GPTNeoXMemoryAnalyzer,
 }


+def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
+    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+    if dtype == 'default':
+        if config_dtype == torch.float32:
+            # Following the common practice, we use float16 for float32 models.
+            torch_dtype = torch.float16
+        else:
+            torch_dtype = config_dtype
+    else:
+        torch_dtype = get_torch_dtype(dtype)
+        if torch_dtype != config_dtype and config_dtype != torch.float32:
+            # TODO(woosuk): Allow using float16 for bfloat16 models and
+            # vice versa. Print a warning message and continue.
+            raise ValueError(
+                f'Cannot use {torch_dtype} for {config_dtype} model.')
+    return torch_dtype
+
+
 def get_model(
     model_name: str,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     cache_dir: Optional[str],
     use_dummy_weights: bool,
     use_np_cache: bool,
 ) -> nn.Module:
-    torch_dtype = get_torch_dtype(dtype)
-    torch.set_default_dtype(torch_dtype)
     config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
+    torch.set_default_dtype(torch_dtype)
     for model_class_name, model_class in _MODELS.items():
         if model_class_name in model_name:
             if use_dummy_weights:
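The new `_get_dtype` helper centralizes the rule that the `--dtype` help text describes: with `'default'`, FP32 checkpoints are served in FP16 while FP16/BF16 checkpoints keep their dtype, and an explicit `'half'` or `'bfloat16'` must not conflict with a non-FP32 checkpoint. The following standalone sketch (not part of the commit) restates that rule, with a plain lookup table in place of `get_torch_dtype`:

```python
# Standalone restatement of the resolution rule implemented by _get_dtype above.
import torch

_STR_TO_DTYPE = {'half': torch.float16, 'bfloat16': torch.bfloat16}

def resolve_dtype(config_dtype: torch.dtype, dtype: str) -> torch.dtype:
    if dtype == 'default':
        # FP32 checkpoints are served in FP16; FP16/BF16 checkpoints keep their dtype.
        return torch.float16 if config_dtype == torch.float32 else config_dtype
    requested = _STR_TO_DTYPE[dtype]
    if requested != config_dtype and config_dtype != torch.float32:
        # Mixing FP16 and BF16 is rejected for now (see the TODO above).
        raise ValueError(f'Cannot use {requested} for {config_dtype} model.')
    return requested

assert resolve_dtype(torch.float32, 'default') == torch.float16    # FP32 model -> FP16
assert resolve_dtype(torch.bfloat16, 'default') == torch.bfloat16  # BF16 model stays BF16
assert resolve_dtype(torch.float32, 'bfloat16') == torch.bfloat16  # explicit dtype for an FP32 model is allowed
```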
@@ -66,12 +87,13 @@ def get_model(
 def get_memory_analyzer(
     model_name: str,
     block_size: int,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     gpu_memory: int,
     cpu_memory: int,
     tensor_parallel_size: int = 1,
 ) -> CacheFlowMemoryAnalyzer:
-    torch_dtype = get_torch_dtype(dtype)
+    config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
     for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
         if model_class in model_name:
             return memory_analyzer(
...
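With this change, callers pass the dtype string from the CLI straight through, and the concrete `torch.dtype` is derived from the checkpoint's config. A hedged usage sketch (not from the commit), assuming the `get_model` defined above and `facebook/opt-125m` as an example checkpoint id:

```python
# Usage sketch only: the caller no longer passes a torch.dtype.
model = get_model(
    model_name='facebook/opt-125m',  # example checkpoint; 'opt' matches the registry
    dtype='default',                 # resolved to a concrete torch.dtype via _get_dtype
    cache_dir=None,
    use_dummy_weights=True,          # use dummy values instead of loading real weights
    use_np_cache=False,
)
```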