Unverified commit 7c041ab5, authored by Woosuk Kwon, committed by GitHub

Refactor system architecture (#82)

parent 8917782a
@@ -4,8 +4,8 @@ import pickle
 import time
 from typing import Any, Dict, List, Optional, Tuple

-from cacheflow.master.block_manager import BlockSpaceManager
-from cacheflow.master.policy import PolicyFactory
+from cacheflow.core.block_manager import BlockSpaceManager
+from cacheflow.core.policy import PolicyFactory
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
...
@@ -8,20 +8,21 @@ try:
 except ImportError:
     ray = None

+from cacheflow.core.scheduler import Scheduler
+from cacheflow.frontend.simple_frontend import SimpleFrontend
 from cacheflow.logger import init_logger
-from cacheflow.master.scheduler import Scheduler
-from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.models import get_memory_analyzer
-from cacheflow.worker.controller import Controller, DeviceID
+from cacheflow.model_executor import get_memory_analyzer
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import get_gpu_memory, get_cpu_memory
+from cacheflow.worker.controller import Controller, DeviceID

 logger = init_logger(__name__)


 class Server:
     def __init__(
         self,
         model: str,
...
 import argparse
 import asyncio
+import json
 import time
 from typing import List, Dict, Optional
-import json

-import ray
-from transformers import AutoTokenizer
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
+import ray
+from transformers import AutoTokenizer
 import uvicorn

+from cacheflow.core.server import (Server, add_server_arguments,
+                                   process_server_arguments,
+                                   initialize_cluster)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     process_server_arguments,
-                                     initialize_cluster)
-from cacheflow.worker.controller import DeviceID
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
+from cacheflow.worker.controller import DeviceID

 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds

 app = FastAPI()
...
+from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.model_loader import get_model, get_memory_analyzer
+from cacheflow.model_executor.utils import set_random_seed
+
+__all__ = [
+    "InputMetadata",
+    "get_model",
+    "get_memory_analyzer",
+    "set_random_seed",
+]
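Note: this new cacheflow/model_executor/__init__.py re-exports the package's main entry points, so callers such as the Server above can import from the package root instead of reaching into submodules. A minimal usage sketch; the seed value is illustrative, and the exact call signatures live in the underlying modules:

# Sketch only: import through the new package facade rather than submodules.
from cacheflow.model_executor import get_model, get_memory_analyzer, set_random_seed

# Seed all RNGs up front so runs with randomly initialized (dummy) weights
# are reproducible. The argument 0 is an arbitrary example seed.
set_random_seed(0)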
@@ -7,7 +7,7 @@ from xformers import ops as xops
 from cacheflow import attention_ops
 from cacheflow import cache_ops
 from cacheflow import pos_encoding_ops
-from cacheflow.models import InputMetadata
+from cacheflow.model_executor.input_metadata import InputMetadata


 class GPTCacheFlowAttention(nn.Module):
...
@@ -3,10 +3,11 @@ from typing import Dict, List, Tuple
 import torch
 import torch.nn as nn

-from cacheflow.models import InputMetadata
+from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+    gather_from_tensor_model_parallel_region)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import SequenceOutputs
-from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_parallel_region


 class Sampler(nn.Module):
@@ -27,7 +28,7 @@ class Sampler(nn.Module):
         # Get the logits for the next tokens.
         logits = torch.matmul(hidden_states, embedding.t())
         logits = gather_from_tensor_model_parallel_region(logits)
-        # Remove paddings in vocab.
+        # Remove paddings in vocab (if any).
         logits = logits[:, :self.vocab_size]
         # Apply temperature scaling.
...
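For context on the logits[:, :self.vocab_size] slice in the Sampler hunk above: with tensor parallelism, the embedding table is commonly padded so the vocabulary divides evenly across ranks, and the gathered logits then carry trailing columns that correspond to no real token. A standalone sketch of that bookkeeping (the sizes and padding scheme here are assumptions for illustration, not taken from this diff):

import torch

# Suppose the true vocab has 50257 tokens but the embedding was padded to
# 50304 so it splits evenly across 8 tensor-parallel ranks (50304 = 8 * 6288).
vocab_size, padded_vocab_size, hidden_size = 50257, 50304, 512
hidden_states = torch.randn(4, hidden_size)            # [num_tokens, hidden_size]
embedding = torch.randn(padded_vocab_size, hidden_size)

logits = torch.matmul(hidden_states, embedding.t())    # [4, 50304], incl. padding
logits = logits[:, :vocab_size]                        # drop padding columns -> [4, 50257]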
@@ -2,7 +2,7 @@ import torch
 from transformers import AutoConfig

 from cacheflow.logger import init_logger
-from cacheflow.models.utils import get_dtype_size
+from cacheflow.model_executor.utils import get_dtype_size

 logger = init_logger(__name__)
...
@@ -5,16 +5,13 @@ import torch.nn as nn
 from transformers import AutoConfig
 from transformers import PretrainedConfig

-from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
-from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
-from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
-from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
-from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
-from cacheflow.models.gpt2 import GPT2LMHeadModel
-from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
-from cacheflow.models.llama import LlamaForCausalLM
-from cacheflow.models.opt import OPTForCausalLM
-from cacheflow.models.utils import get_torch_dtype
+from cacheflow.model_executor.memory_analyzer import (
+    CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer,
+    LlamaMemoryAnalyzer, OPTMemoryAnalyzer)
+from cacheflow.model_executor.models import (
+    GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
+from cacheflow.model_executor.utils import get_torch_dtype
+from cacheflow.model_executor.weight_utils import initialize_dummy_weights

 _MODELS = {
@@ -77,7 +74,7 @@ def get_model(
         model = model.cuda()
         # NOTE(woosuk): For precise performance evaluation, we assign
         # random values to the weights.
-        model.initialize_dummy_weights()
+        initialize_dummy_weights(model)
     else:
         # Create a model instance.
         model = model_class(config)
...
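The per-model initialize_dummy_weights methods deleted later in this diff are replaced by a single helper in cacheflow.model_executor.weight_utils (the import above confirms the module). Based on the deleted method bodies, the shared helper amounts to the following sketch; the actual file may differ in details:

import torch.nn as nn

def initialize_dummy_weights(model: nn.Module) -> None:
    # Same behavior as the removed per-model methods: fill every parameter
    # with small uniform random values so performance benchmarks can run
    # without loading real checkpoints.
    for param in model.state_dict().values():
        param.data.uniform_(-1e-3, 1e-3)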
+from cacheflow.model_executor.models.gpt_neox import GPTNeoXForCausalLM
+from cacheflow.model_executor.models.gpt2 import GPT2LMHeadModel
+from cacheflow.model_executor.models.llama import LlamaForCausalLM
+from cacheflow.model_executor.models.opt import OPTForCausalLM
+
+__all__ = [
+    "GPT2LMHeadModel",
+    "GPTNeoXForCausalLM",
+    "LlamaForCausalLM",
+    "OPTForCausalLM",
+]
@@ -5,16 +5,15 @@ import torch
 from torch import nn
 from transformers import GPT2Config

-from cacheflow.models import InputMetadata
-from cacheflow.models.attention import GPTCacheFlowAttention
-from cacheflow.models.sample import Sampler
-from cacheflow.models.utils import (hf_model_weights_iterator,
-                                    load_tensor_parallel_weights)
-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
+from cacheflow.model_executor.layers.sampler import Sampler
+from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
+                                                   load_tensor_parallel_weights)
+from cacheflow.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
-                                                      ColumnParallelLinear,
-                                                      RowParallelLinear)
+from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from cacheflow.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -258,8 +257,5 @@ class GPT2LMHeadModel(nn.Module):
                 raise ValueError(f"Unexpected parameter name {name}")
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
-                                         self._row_parallel_weights)
-
-    def initialize_dummy_weights(self) -> None:
-        for param in self.state_dict().values():
-            param.data.uniform_(-1e-3, 1e-3)
+                                         self._row_parallel_weights,
+                                         tensor_model_parallel_rank)
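Across the model files in this commit, load_tensor_parallel_weights now receives the caller's tensor_model_parallel_rank explicitly instead of looking it up itself. The rank is what selects this worker's shard of each partitioned weight; a simplified sketch of that slicing logic (an illustration of the idea, not the actual cacheflow implementation):

import torch

def load_column_parallel_weight(param: torch.Tensor,
                                loaded_weight: torch.Tensor,
                                rank: int) -> None:
    # Column-parallel layers split the output dimension across ranks, so
    # rank i copies rows [i * shard, (i + 1) * shard) of the full weight
    # into its local parameter.
    shard = param.shape[0]
    param.data.copy_(loaded_weight[rank * shard:(rank + 1) * shard])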
@@ -3,17 +3,17 @@ from typing import Dict, List, Optional, Tuple
 import torch
 from torch import nn
+from transformers import GPTNeoXConfig

-from cacheflow.models import InputMetadata
-from cacheflow.models.attention import GPTNeoXCacheFlowAttention
-from cacheflow.models.sample import Sampler
-from cacheflow.models.utils import (hf_model_weights_iterator,
-                                    load_tensor_parallel_weights)
-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
+from cacheflow.model_executor.layers.sampler import Sampler
+from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
+                                                   load_tensor_parallel_weights)
+from cacheflow.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
-                                                      ColumnParallelLinear,
-                                                      RowParallelLinear)
+from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from cacheflow.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -21,7 +21,7 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
 class GPTNeoXAttention(nn.Module):

-    def __init__(self, config):
+    def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.total_num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
@@ -63,7 +63,7 @@ class GPTNeoXAttention(nn.Module):
 class GPTNeoXMLP(nn.Module):

-    def __init__(self, config):
+    def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
                                                   config.intermediate_size,
@@ -86,7 +86,7 @@ class GPTNeoXMLP(nn.Module):
 class GPTNeoXLayer(nn.Module):

-    def __init__(self, config):
+    def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -129,7 +129,7 @@ class GPTNeoXLayer(nn.Module):
 class GPTNeoXModel(nn.Module):

-    def __init__(self, config):
+    def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.config = config
@@ -227,8 +227,5 @@ class GPTNeoXForCausalLM(nn.Module):
                 raise ValueError(f"Unexpected weight name: {name}")
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
-                                         self._row_parallel_weights)
-
-    def initialize_dummy_weights(self) -> None:
-        for param in self.state_dict().values():
-            param.data.uniform_(-1e-3, 1e-3)
+                                         self._row_parallel_weights,
+                                         tensor_model_parallel_rank)
@@ -5,18 +5,18 @@ import torch
 from torch import nn
 from transformers import LlamaConfig

-from cacheflow.models import InputMetadata
-from cacheflow.models.activation import SiluAndMul
-from cacheflow.models.attention import GPTNeoXCacheFlowAttention
-from cacheflow.models.layernorm import RMSNorm
-from cacheflow.models.sample import Sampler
-from cacheflow.models.utils import (hf_model_weights_iterator,
-                                    load_tensor_parallel_weights)
-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.sequence import SequenceOutputs
+from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import SiluAndMul
+from cacheflow.model_executor.layers.layernorm import RMSNorm
+from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
+from cacheflow.model_executor.layers.sampler import Sampler
+from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
+                                                   load_tensor_parallel_weights)
+from cacheflow.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
-                                                      ColumnParallelLinear,
-                                                      RowParallelLinear)
+from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from cacheflow.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -263,8 +263,5 @@ class LlamaForCausalLM(nn.Module):
             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
-                                         self._row_parallel_weights)
-
-    def initialize_dummy_weights(self) -> None:
-        for param in self.state_dict().values():
-            param.data.uniform_(-1e-3, 1e-3)
+                                         self._row_parallel_weights,
+                                         tensor_model_parallel_rank)
@@ -5,16 +5,15 @@ import torch
 from torch import nn
 from transformers import OPTConfig

-from cacheflow.models import InputMetadata
-from cacheflow.models.attention import GPTCacheFlowAttention
-from cacheflow.models.sample import Sampler
-from cacheflow.models.utils import (hf_model_weights_iterator,
-                                    load_tensor_parallel_weights)
-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
+from cacheflow.model_executor.layers.sampler import Sampler
+from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
+                                                   load_tensor_parallel_weights)
+from cacheflow.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
-                                                      ColumnParallelLinear,
-                                                      RowParallelLinear)
+from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from cacheflow.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -288,8 +287,5 @@ class OPTForCausalLM(nn.Module):
             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
-                                         self._row_parallel_weights)
-
-    def initialize_dummy_weights(self) -> None:
-        for param in self.state_dict().values():
-            param.data.uniform_(-1e-3, 1e-3)
+                                         self._row_parallel_weights,
+                                         tensor_model_parallel_rank)