"megatron/data/bert_dataset.py" did not exist on "f6a6811fdf4bed14569b4d9e664216a0acc9874c"
Unverified commit 7c041ab5, authored by Woosuk Kwon, committed by GitHub

Refactor system architecture (#82)

parent 8917782a
-import cacheflow.parallel_utils.parallel_state
-import cacheflow.parallel_utils.tensor_parallel
-import cacheflow.parallel_utils.utils
+import cacheflow.model_executor.parallel_utils.parallel_state
+import cacheflow.model_executor.parallel_utils.tensor_parallel
+import cacheflow.model_executor.parallel_utils.utils
# Alias parallel_state as mpu, its legacy name
mpu = parallel_state
......
@@ -9,7 +9,7 @@ import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_all_reduce_launcher,
......
@@ -2,7 +2,7 @@
import torch
-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_group,
......
@@ -10,7 +10,7 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
-from cacheflow.parallel_utils.parallel_state import (
+from cacheflow.model_executor.parallel_utils.parallel_state import (
    get_data_parallel_rank,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
@@ -22,7 +22,7 @@ from .utils import (
    gather_split_1d_tensor,
)
-from cacheflow.parallel_utils.utils import safely_set_viewless_tensor_data
+from cacheflow.model_executor.parallel_utils.utils import safely_set_viewless_tensor_data
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
......
@@ -3,8 +3,8 @@
import torch
from typing import List, Sequence
-from cacheflow.parallel_utils.utils import divide
-from cacheflow.parallel_utils import parallel_state
+from cacheflow.model_executor.parallel_utils.utils import divide
+from cacheflow.model_executor.parallel_utils import parallel_state
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
......
@@ -4,7 +4,7 @@ import operator
import torch
-from cacheflow.parallel_utils import parallel_state
+from cacheflow.model_executor.parallel_utils import parallel_state
def ensure_divisibility(numerator, denominator):
......
import random
from typing import Union

import numpy as np
import torch

from cacheflow.model_executor.parallel_utils.parallel_state import model_parallel_is_initialized
from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed

_STR_DTYPE_TO_TORCH_DTYPE = {
    'half': torch.half,
    'float': torch.float,
    'float16': torch.float16,
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
}


def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
    if isinstance(dtype, str):
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
    else:
        torch_dtype = dtype
    return torch_dtype


def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
    torch_dtype = get_torch_dtype(dtype)
    return torch.tensor([], dtype=torch_dtype).element_size()


def set_random_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if model_parallel_is_initialized():
        model_parallel_cuda_manual_seed(seed)
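Reviewer note: a minimal sketch of how these helpers behave. The byte sizes follow directly from the torch dtypes; the snippet is illustrative and not part of the commit.

# get_dtype_size reports bytes per element via an empty tensor's element_size():
# 'half'/'float16'/'bfloat16' -> 2 bytes, 'float'/'float32' -> 4 bytes.
assert get_dtype_size('half') == 2
assert get_dtype_size('float32') == 4
assert get_torch_dtype('BFLOAT16') is torch.bfloat16  # lookup is case-insensitive

# set_random_seed seeds Python, NumPy, and torch (plus the model-parallel RNG
# tracker when model parallelism is initialized), making runs reproducible.
set_random_seed(0)
first = torch.randn(2, 2)
set_random_seed(0)
assert torch.equal(first, torch.randn(2, 2))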
-import os
+import filelock
import glob
import json
-import filelock
-from typing import Union, Optional
+import os
+from typing import Iterator, List, Optional, Tuple
+from huggingface_hub import snapshot_download
import numpy as np
import torch
from tqdm.auto import tqdm
-from huggingface_hub import snapshot_download
-from cacheflow.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank)
-_STR_DTYPE_TO_TORCH_DTYPE = {
-    'half': torch.half,
-    'float': torch.float,
-    'float16': torch.float16,
-    'float32': torch.float32,
-    'bfloat16': torch.bfloat16,
-}
-def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
-    if isinstance(dtype, str):
-        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
-    else:
-        torch_dtype = dtype
-    return torch_dtype
-def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
-    torch_dtype = get_torch_dtype(dtype)
-    return torch.tensor([], dtype=torch_dtype).element_size()
class Disabledtqdm(tqdm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, disable=True)
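Reviewer note: this subclass presumably exists to silence download progress bars. huggingface_hub's snapshot_download does accept a tqdm_class parameter, but the call below (including the model id) is only an illustrative sketch, not code from the commit.

# Hypothetical usage: pass the silenced tqdm subclass so weight downloads
# print no progress bars on non-driver processes.
folder = snapshot_download('facebook/opt-125m', tqdm_class=Disabledtqdm)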
-def hf_model_weights_iterator(model_name_or_path: str,
-                              cache_dir: Optional[str] = None,
-                              use_np_cache: bool = False):
+def hf_model_weights_iterator(
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+    use_np_cache: bool = False,
+) -> Iterator[Tuple[str, torch.Tensor]]:
    # Prepare file lock directory to prevent multiple processes from
    # downloading the same model weights at the same time.
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
@@ -95,10 +74,14 @@ def hf_model_weights_iterator(model_name_or_path: str,
        yield name, param
-def load_tensor_parallel_weights(param, loaded_weight, param_name,
-                                 column_parallel_weight_names,
-                                 row_parallel_weight_names):
-    tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+def load_tensor_parallel_weights(
+    param: torch.Tensor,
+    loaded_weight: torch.Tensor,
+    param_name: str,
+    column_parallel_weight_names: List[str],
+    row_parallel_weight_names: List[str],
+    tensor_model_parallel_rank: int,
+) -> None:
    for p in column_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[0]
@@ -116,3 +99,12 @@ def load_tensor_parallel_weights(param, loaded_weight, param_name,
            break
    assert param.shape == loaded_weight.shape
    param.data.copy_(loaded_weight)
+def initialize_dummy_weights(
+    model: torch.nn.Module,
+    low: float = -1e-3,
+    high: float = 1e-3,
+) -> None:
+    for param in model.state_dict().values():
+        param.data.uniform_(low, high)
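Reviewer note: to make the sharding contract behind the new explicit tensor_model_parallel_rank argument concrete, here is a minimal sketch. The shapes, world size, and rank below are invented for illustration; in the real code the rank would come from get_tensor_model_parallel_rank().

import torch

# Hypothetical full weight and a 2-way tensor-parallel split.
full_weight = torch.randn(8, 4)   # e.g. a column-parallel layer's weight
world_size = 2
rank = 1

# Column-parallel layers shard along dim 0: the local parameter holds
# shape[0] // world_size rows, and rank k copies rows [k*s : (k+1)*s].
param = torch.empty(8 // world_size, 4)
shard_size = param.shape[0]
loaded = full_weight[shard_size * rank:shard_size * (rank + 1)]
assert param.shape == loaded.shape
param.data.copy_(loaded)

By contrast, initialize_dummy_weights skips loading entirely and fills every parameter with small uniform noise, which is enough for profiling memory usage without downloading real checkpoints.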
-from cacheflow.models.input_metadata import InputMetadata
-from cacheflow.models.model_utils import get_memory_analyzer
-from cacheflow.models.model_utils import get_model
-__all__ = [
-    'InputMetadata',
-    'get_memory_analyzer',
-    'get_model',
-]
import enum
-import random
+import psutil
-import numpy as np
-import psutil
import torch
-from cacheflow.parallel_utils.parallel_state import model_parallel_is_initialized
-from cacheflow.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed

class Device(enum.Enum):
    GPU = enum.auto()
@@ -28,17 +23,6 @@ class Counter:
        self.counter = 0
-def set_random_seed(seed: int):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
-    if model_parallel_is_initialized():
-        model_parallel_cuda_manual_seed(seed)
def get_gpu_memory(gpu: int = 0) -> int:
    return torch.cuda.get_device_properties(gpu).total_memory
......
@@ -5,7 +5,7 @@ try:
except ImportError:
    ray = None
-from cacheflow.master.scheduler import Scheduler
+from cacheflow.core.scheduler import Scheduler
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.worker.worker import Worker
......
@@ -2,18 +2,15 @@ from typing import Dict, List, Tuple, Optional
import torch
-from cacheflow.models import get_model
-from cacheflow.models import InputMetadata
+from cacheflow.model_executor import get_model, InputMetadata, set_random_seed
+from cacheflow.model_executor.parallel_utils.parallel_state import (
+    initialize_model_parallel,
+    initialize_all_reduce_launcher,
+    get_tensor_model_parallel_world_size)
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.worker.cache_engine import CacheEngine
-from cacheflow.parallel_utils.parallel_state import (
-    initialize_model_parallel,
-    initialize_all_reduce_launcher,
-    get_tensor_model_parallel_world_size)
-from cacheflow.utils import set_random_seed
class Worker:
......
import argparse
import os
import pickle
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np

SYSTEMS = [
    'orca-constant',
    'orca-power2',
    'orca-oracle',
    'cacheflow',
]

SYSTEM_TO_LABEL = {
    'orca-constant': 'Orca (Max)',
    'orca-power2': 'Orca (Pow2)',
    'orca-oracle': 'Orca (Oracle)',
    'cacheflow': 'KVFlow',
}

SYSTEM_TO_COLOR = {
    'orca-constant': 'red',
    'orca-power2': 'orange',
    'orca-oracle': 'green',
    'cacheflow': 'blue',
}

SYSTEM_TO_MARKER = {
    'orca-constant': 'x',
    'orca-power2': '^',
    'orca-oracle': 's',
    'cacheflow': 'o',
}


def get_results(save_dir: str) -> List[Dict[str, Any]]:
    with open(os.path.join(save_dir, 'sequences.pkl'), 'rb') as f:
        results = pickle.load(f)
    return results


def get_request_rate(save_dir: str) -> float:
    """Get request rate from save_dir name."""
    # Directory name format:
    # .../req-rate-{req_rate}/seed-{seed}/duration-{duration}
    save_dir = os.path.abspath(save_dir)
    dir_names = save_dir.split('/')

    request_rate = None
    for dir_name in dir_names:
        if dir_name.startswith('req-rate-'):
            if request_rate is not None:
                raise ValueError(f'Found multiple request rates in {save_dir}')
            request_rate = float(dir_name.split('-')[-1])
    if request_rate is None:
        raise ValueError(f'Cannot find request rate in {save_dir}')
    return request_rate


def get_model(save_dir: str) -> Tuple[str, int]:
    save_dir = os.path.abspath(save_dir)
    dir_names = save_dir.split('/')

    model = None
    for dir_name in dir_names:
        if '-tp' in dir_name:
            if model is not None:
                raise ValueError(f'Found multiple models in {save_dir}')
            model = dir_name.split('-tp')[0]
            tp = int(dir_name.split('-tp')[-1])
    if model is None:
        raise ValueError(f'Cannot find model in {save_dir}')
    return model, tp


def get_system(save_dir: str) -> str:
    save_dir = os.path.abspath(save_dir)
    dir_names = save_dir.split('/')
    for dir_name in dir_names:
        if dir_name.startswith('orca-'):
            return dir_name
        if dir_name == 'cacheflow':
            return dir_name
    raise ValueError(f'Cannot find system in {save_dir}')


def get_sampling(save_dir: str) -> str:
    save_dir = os.path.abspath(save_dir)
    dir_names = save_dir.split('/')
    for dir_name in dir_names:
        if dir_name.startswith('n'):
            if dir_name.endswith('-beam'):
                return dir_name
            if dir_name[1:].isdigit():
                return dir_name
    raise ValueError(f'Cannot find sampling method in {save_dir}')
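Reviewer note: these helpers recover all experiment metadata purely from path components, following the directory layout documented in get_request_rate. A sketch with a hypothetical path (every component below is invented for illustration):

# Hypothetical save_dir following .../req-rate-{r}/seed-{s}/duration-{d}.
path = '/exp/opt-13b-tp4/cacheflow/n1/req-rate-2.0/seed-0/duration-900'
assert get_request_rate(path) == 2.0       # parsed from 'req-rate-2.0'
assert get_model(path) == ('opt-13b', 4)   # 'opt-13b-tp4' -> (model, tp degree)
assert get_system(path) == 'cacheflow'
assert get_sampling(path) == 'n1'          # 'n1' = 1 sample; 'n4-beam' = beam search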
def plot_normalized_latency(
    exp_dir: str,
    duration: int,
    seed: int,
    warmup: int,
    xlim: Optional[float],
    ylim: Optional[float],
    log_scale: bool,
    format: str,
) -> None:
    # Get leaf directories.
    save_dirs = []
    for root, dirs, files in os.walk(exp_dir):
        if dirs:
            continue
        if 'sequences.pkl' not in files:
            continue
        # Directories are named 'seed-{seed}' (see get_request_rate).
        if f'seed-{seed}' not in root:
            continue
        if f'duration-{duration}' not in root:
            continue
        save_dirs.append(root)

    # Collect per-sequence normalized latencies.
    perf_per_system: Dict[str, Tuple[List[float], List[float]]] = {}
    for save_dir in save_dirs:
        per_seq_norm_latencies = []
        results = get_results(save_dir)
        for seq in results:
            arrival_time = seq['arrival_time']
            finish_time = seq['finish_time']
            output_len = seq['output_len']
            if arrival_time < warmup:
                continue
            latency = finish_time - arrival_time
            norm_latency = latency / output_len
            per_seq_norm_latencies.append(norm_latency)

        request_rate = get_request_rate(save_dir)
        normalized_latency = np.mean(per_seq_norm_latencies)

        system_name = get_system(save_dir)
        if system_name not in perf_per_system:
            perf_per_system[system_name] = ([], [])
        perf_per_system[system_name][0].append(request_rate)
        perf_per_system[system_name][1].append(normalized_latency)

        print('#seqs', len(per_seq_norm_latencies))
        print(f'{save_dir}: {normalized_latency:.3f} s')

    # Plot normalized latency.
    plt.figure(figsize=(6, 4))
    for system_name in reversed(SYSTEMS):
        if system_name not in perf_per_system:
            continue
        # Sort by request rate.
        request_rates, normalized_latencies = perf_per_system[system_name]
        request_rates, normalized_latencies = zip(
            *sorted(zip(request_rates, normalized_latencies)))
        label = SYSTEM_TO_LABEL[system_name]
        color = SYSTEM_TO_COLOR[system_name]
        marker = SYSTEM_TO_MARKER[system_name]
        plt.plot(request_rates, normalized_latencies,
                 label=label, color=color, marker=marker)

    plt.xlabel('Request rate (req/s)', fontsize=12)
    plt.ylabel('Normalized latency (s/token)', fontsize=12)
    if log_scale:
        plt.yscale('log')
    if xlim is not None:
        plt.xlim(left=0, right=xlim)
    if ylim is not None:
        if log_scale:
            plt.ylim(top=ylim)
        else:
            plt.ylim(bottom=0, top=ylim)

    handles, labels = plt.gca().get_legend_handles_labels()
    handles = reversed(handles)
    labels = reversed(labels)
    plt.legend(
        handles, labels,
        ncol=4, fontsize=12, loc='upper center', bbox_to_anchor=(0.5, 1.15),
        columnspacing=0.5, handletextpad=0.5, handlelength=1.5,
        frameon=False, borderpad=0)

    # Save figure.
    model, tp = get_model(exp_dir)
    sampling = get_sampling(exp_dir)
    figname = f'{model}-tp{tp}-{sampling}.{format}'
    os.makedirs('./figures', exist_ok=True)
    plt.savefig(os.path.join('figures', figname), bbox_inches='tight')
    print(f'Saved figure to ./figures/{figname}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('exp_dir', type=str)
    parser.add_argument('--duration', type=int, required=True)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--warmup', type=int, default=60)
    parser.add_argument('--xlim', type=float, required=False, default=None)
    parser.add_argument('--ylim', type=float, required=False, default=None)
    parser.add_argument('--log', action='store_true')
    parser.add_argument('--format', choices=['png', 'pdf'], default='png')
    args = parser.parse_args()

    plot_normalized_latency(
        args.exp_dir, args.duration, args.seed, args.warmup,
        args.xlim, args.ylim, args.log, args.format)
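Reviewer note: for reference, a hypothetical invocation of the entry point. The experiment directory and argument values are illustrative, not from the commit.

# Equivalent to: python plot_normalized_latency.py /exp/opt-13b-tp4 --duration 900 --log
plot_normalized_latency(
    exp_dir='/exp/opt-13b-tp4', duration=900, seed=0, warmup=60,
    xlim=None, ylim=None, log_scale=True, format='png')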
import os
import pickle

import matplotlib.pyplot as plt

STAT_NAMES = [
    'input_lens',
    'num_running',
    'num_waiting',
    'num_preemption',
    'gpu_cache_usage',
    'cpu_cache_usage',
    'num_swapped',
    'swap_in_lens',
    'swap_out_lens',
]


def plot_stats(output_dir: str):
    # Get stats.
    with open(os.path.join(output_dir, 'stats.pkl'), 'rb') as f:
        stats = pickle.load(f)
    timestamps = stats['timestamps']

    # Draw one figure for each stat.
    num_stats = len(STAT_NAMES)
    COLORS = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple',
              'pink', 'brown', 'gray']
    fig, axs = plt.subplots(num_stats, 1, figsize=(10, 2 * num_stats))
    for i, stat in enumerate(STAT_NAMES):
        data = stats[stat]
        if stat in ['gpu_cache_usage', 'cpu_cache_usage']:
            # Convert fractions to percentages.
            data = [x * 100 for x in data]
            stat = stat + ' (%)'
        axs[i].plot(timestamps, data, color=COLORS[i % len(COLORS)])
        axs[i].set_ylabel(stat.replace('_', ' '), fontdict={'fontsize': 12})
        axs[i].set_ylim(bottom=0)
    plt.xlabel('Time (s)')
    plt.tight_layout()

    fig_path = os.path.join(output_dir, 'stats.png')
    plt.savefig(fig_path)
    print(f'Saved stats to {fig_path}')


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('output_dir', type=str, help='Output directory.')
    args = parser.parse_args()
    plot_stats(args.output_dir)
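Reviewer note: the script assumes stats.pkl holds a dict mapping 'timestamps' plus each name in STAT_NAMES to equal-length series. A minimal sketch of a compatible producer (all values fabricated for illustration):

import pickle

# Hypothetical stats.pkl writer matching what plot_stats reads.
stats = {'timestamps': [0.0, 1.0, 2.0]}
for name in STAT_NAMES:
    stats[name] = [0.0, 0.0, 0.0]
stats['gpu_cache_usage'] = [0.1, 0.5, 0.9]   # fractions; plotted as percentages

with open('stats.pkl', 'wb') as f:
    pickle.dump(stats, f)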
import argparse
-from cacheflow.master.server import (
+from cacheflow.core.server import (
    add_server_arguments, process_server_arguments,
    init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams
......