Commit 1ce13335 authored by Woosuk Kwon

Set default dtype to half

parent de0fabbc
+from typing import Union
+
+import torch
 import torch.nn as nn

 from cacheflow.models.opt import OPTForCausalLM

@@ -6,9 +9,23 @@
 MODEL_CLASSES = {
     'opt': OPTForCausalLM,
 }

+STR_DTYPE_TO_TORCH_DTYPE = {
+    'half': torch.half,
+    'float': torch.float,
+    'float16': torch.float16,
+    'float32': torch.float32,
+}
+

-def get_model(model_name: str) -> nn.Module:
+def get_model(
+    model_name: str,
+    dtype: Union[torch.dtype, str],
+) -> nn.Module:
+    if isinstance(dtype, str):
+        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
+    else:
+        torch_dtype = dtype
     for model_class, model in MODEL_CLASSES.items():
         if model_class in model_name:
-            return model.from_pretrained(model_name)
+            return model.from_pretrained(model_name, torch_dtype=torch_dtype)
     raise ValueError(f'Invalid model name: {model_name}')
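
For reference, a minimal usage sketch of the new signature. The checkpoint name and import path are illustrative assumptions, not part of this commit; any model name containing a key of MODEL_CLASSES (here 'opt') resolves the same way.

    import torch

    from cacheflow.models import get_model  # assumed import path, not shown in this diff

    # A string dtype is lowercased and looked up in STR_DTYPE_TO_TORCH_DTYPE...
    model_fp16 = get_model('facebook/opt-125m', dtype='half')
    # ...while a torch.dtype is forwarded to from_pretrained unchanged.
    model_fp32 = get_model('facebook/opt-125m', dtype=torch.float32)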
@@ -14,6 +14,7 @@ class Controller:
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
+        dtype: str = 'half',
     ) -> None:
         self.node_id = node_id
         self.num_workers = num_workers
@@ -35,6 +36,7 @@ class Controller:
                 block_size=block_size,
                 num_gpu_blocks=num_gpu_blocks,
                 num_cpu_blocks=num_cpu_blocks,
+                dtype=dtype,
             )
             self.workers.append(worker)
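
Taken together, precision is now chosen once at the Controller (defaulting to 'half') and threaded down to every Worker it spawns. A condensed sketch of that flow; the trimmed constructor arguments are hypothetical stand-ins, since this diff shows the full signatures only partially.

    class Worker:
        def __init__(self, model_name: str, dtype: str) -> None:
            # Each worker loads its own model replica in the requested precision.
            self.model = get_model(model_name, dtype=dtype)

    class Controller:
        def __init__(self, model_name: str, num_workers: int, dtype: str = 'half') -> None:
            # 'half' is the new default; callers can still override it explicitly.
            self.workers = [Worker(model_name, dtype) for _ in range(num_workers)]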
@@ -17,6 +17,7 @@ class Worker:
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
+        dtype: str,
     ) -> None:
         self.worker_id = worker_id
         self.gpu_id = gpu_id
@@ -26,7 +27,7 @@ class Worker:
         # Initialize the model.
         # FIXME(woosuk): This is a hack.
-        self.model = get_model(model_name).to(device=gpu_id)
+        self.model = get_model(model_name, dtype=dtype).to(device=self.device)
         self.num_layers = self.model.config.num_hidden_layers
         self.num_heads = self.model.config.num_attention_heads
         self.head_size = self.model.config.hidden_size // self.num_heads
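
The arithmetic behind a half-precision default is simple: fp16 stores each parameter in 2 bytes instead of fp32's 4, halving the weight memory every worker must fit on its GPU. A back-of-the-envelope check, using OPT-125M's parameter count as an illustrative figure:

    num_params = 125_000_000      # e.g. OPT-125M (illustrative)
    bytes_half = num_params * 2   # torch.half: 2 bytes per parameter
    bytes_float = num_params * 4  # torch.float: 4 bytes per parameter
    print(f'{bytes_half / 1e9:.2f} GB (half) vs {bytes_float / 1e9:.2f} GB (float)')
    # -> 0.25 GB (half) vs 0.50 GB (float)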