Commit 1ce13335 authored by Woosuk Kwon
Browse files

Set default dtype to half

parent de0fabbc
from typing import Union
import torch
import torch.nn as nn import torch.nn as nn
from cacheflow.models.opt import OPTForCausalLM from cacheflow.models.opt import OPTForCausalLM
...@@ -6,9 +9,23 @@ MODEL_CLASSES = { ...@@ -6,9 +9,23 @@ MODEL_CLASSES = {
'opt': OPTForCausalLM, 'opt': OPTForCausalLM,
} }
# Accepted string spellings for the model dtype, mapped to the torch dtype
# that is forwarded to `from_pretrained`. Lookups are done on the lower-cased
# string, so keys must stay lower-case.
STR_DTYPE_TO_TORCH_DTYPE = {
    'half': torch.half,
    'float': torch.float,
    'float16': torch.float16,
    'float32': torch.float32,
}


def get_model(
    model_name: str,
    dtype: Union[torch.dtype, str],
) -> nn.Module:
    """Load the model registered for ``model_name`` with the requested dtype.

    Args:
        model_name: Name passed to ``from_pretrained``. The model class is
            picked by substring match: the first key of ``MODEL_CLASSES``
            contained in ``model_name`` wins.
        dtype: Either a ``torch.dtype`` (used as-is) or one of the string
            keys of ``STR_DTYPE_TO_TORCH_DTYPE`` (case-insensitive).

    Returns:
        The object returned by the matched class's ``from_pretrained``.

    Raises:
        ValueError: If ``dtype`` is a string that is not a supported
            spelling, or if no registered model class matches ``model_name``.
    """
    if isinstance(dtype, str):
        try:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
        except KeyError:
            # Report bad dtypes the same way as bad model names instead of
            # leaking a bare KeyError from the lookup table.
            raise ValueError(
                f'Invalid dtype: {dtype!r}. Supported dtype strings: '
                f'{sorted(STR_DTYPE_TO_TORCH_DTYPE)}') from None
    else:
        # Non-string values are forwarded unchanged.
        torch_dtype = dtype

    for model_class, model in MODEL_CLASSES.items():
        if model_class in model_name:
            return model.from_pretrained(model_name, torch_dtype=torch_dtype)
    raise ValueError(f'Invalid model name: {model_name}')
...@@ -14,6 +14,7 @@ class Controller: ...@@ -14,6 +14,7 @@ class Controller:
block_size: int, block_size: int,
num_gpu_blocks: int, num_gpu_blocks: int,
num_cpu_blocks: int, num_cpu_blocks: int,
dtype: str = 'half',
) -> None: ) -> None:
self.node_id = node_id self.node_id = node_id
self.num_workers = num_workers self.num_workers = num_workers
...@@ -35,6 +36,7 @@ class Controller: ...@@ -35,6 +36,7 @@ class Controller:
block_size=block_size, block_size=block_size,
num_gpu_blocks=num_gpu_blocks, num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks, num_cpu_blocks=num_cpu_blocks,
dtype=dtype,
) )
self.workers.append(worker) self.workers.append(worker)
......
...@@ -17,6 +17,7 @@ class Worker: ...@@ -17,6 +17,7 @@ class Worker:
block_size: int, block_size: int,
num_gpu_blocks: int, num_gpu_blocks: int,
num_cpu_blocks: int, num_cpu_blocks: int,
dtype: str,
) -> None: ) -> None:
self.worker_id = worker_id self.worker_id = worker_id
self.gpu_id = gpu_id self.gpu_id = gpu_id
...@@ -26,7 +27,7 @@ class Worker: ...@@ -26,7 +27,7 @@ class Worker:
# Initialize the model. # Initialize the model.
# FIXME(woosuk): This is a hack. # FIXME(woosuk): This is a hack.
self.model = get_model(model_name).to(device=gpu_id) self.model = get_model(model_name, dtype=dtype).to(device=self.device)
self.num_layers = self.model.config.num_hidden_layers self.num_layers = self.model.config.num_hidden_layers
self.num_heads = self.model.config.num_attention_heads self.num_heads = self.model.config.num_attention_heads
self.head_size = self.model.config.hidden_size // self.num_heads self.head_size = self.model.config.hidden_size // self.num_heads
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment