"vscode:/vscode.git/clone" did not exist on "03f94a698578031278f9c3c588a1dcd82b80d7b8"
Commit 8290fce4 authored by Woosuk Kwon

Add Worker class

parent 7b6844e5
from typing import Dict, List, Tuple

import torch

from cacheflow.models import get_model
from cacheflow.models import InputMetadata
from cacheflow.worker.cache_engine import CacheEngine


class Worker:
    """Executes the model on a single GPU and manages that GPU's KV cache
    blocks through a CacheEngine."""

    def __init__(
        self,
        worker_id: int,
        gpu_id: int,
        model_name: str,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.worker_id = worker_id
        self.gpu_id = gpu_id
        self.block_size = block_size
        self.device = torch.device('cuda', index=gpu_id)

        # Initialize the model.
        # FIXME(woosuk): This is a hack.
        self.model = get_model(model_name).to(device=gpu_id)
        self.num_layers = self.model.config.num_hidden_layers
        self.num_heads = self.model.config.num_attention_heads
        self.head_size = self.model.config.hidden_size // self.num_heads
        self.dtype = self.model.dtype

        self.cache_engine = CacheEngine(
            worker_id=worker_id,
            gpu_id=gpu_id,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            head_size=self.head_size,
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            dtype=self.dtype,
        )
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def prepare_inputs(
        self,
        prompt_tokens: Dict[int, List[int]],    # Seq id -> List of input token ids.
        generation_tokens: Dict[int, int],      # Seq id -> Input token id.
        context_lens: Dict[int, int],           # Seq id -> Number of tokens participating in attention.
        block_tables: Dict[int, List[int]],     # Seq id -> List of physical block numbers.
    ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
        """Flatten the per-sequence inputs into padded tensors and the
        metadata needed for a single forward pass."""
        # TODO(woosuk): Support interactive generation.

        # Add the prompt tokens.
        prompt_lens: List[int] = []
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        prompt_seq_ids = sorted(prompt_tokens.keys())
        for seq_id in prompt_seq_ids:
            prompt_len = len(prompt_tokens[seq_id])
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens[seq_id])
            input_positions.extend(range(len(prompt_tokens[seq_id])))

            # Map each prompt token to its physical slot in the KV cache.
            block_table = block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Add the generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        generation_block_tables: List[List[int]] = []

        generation_seq_ids = sorted(generation_tokens.keys())
        for seq_id in generation_seq_ids:
            input_tokens.append(generation_tokens[seq_id])
            input_positions.append(context_lens[seq_id] - 1)
            generation_block_tables.append(block_tables[seq_id])

            max_context_len = max(max_context_len, context_lens[seq_id])
            max_num_blocks_per_seq = max(
                max_num_blocks_per_seq, len(block_tables[seq_id]))

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
        input_positions = _pad_to_alignment(input_positions, multiple_of=8)

        # Convert to tensors.
        tokens_tensor = torch.tensor(
            input_tokens, dtype=torch.long, device=self.device)
        positions_tensor = torch.tensor(
            input_positions, dtype=torch.long, device=self.device)
        slot_mapping_tensor = torch.tensor(
            slot_mapping, dtype=torch.int, device=self.device)
        context_lens_tensor = torch.tensor(
            [context_lens[seq_id] for seq_id in generation_seq_ids],
            dtype=torch.int, device=self.device)
        # Pad every block table to the same length so they can be stacked
        # into one tensor.
        block_tables_tensor = torch.tensor(
            [_pad_to_max(block_table, max_num_blocks_per_seq)
             for block_table in generation_block_tables],
            dtype=torch.int, device=self.device)

        input_metadata = InputMetadata(
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_stage(
        self,
        prompt_tokens: Dict[int, List[int]],    # Seq id -> List of input token ids.
        generation_tokens: Dict[int, int],      # Seq id -> Input token id.
        context_lens: Dict[int, int],           # Seq id -> Number of tokens participating in attention.
        block_tables: Dict[int, List[int]],     # Seq id -> List of physical block numbers.
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, int],
    ) -> torch.Tensor:
        """Issue the requested cache operations, then run one forward pass
        over the given prompt and generation tokens."""
        # Issue cache operations.
        command_issued = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            command_issued = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            command_issued = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            command_issued = True

        # Only pass the cache events to the model if cache operations were issued.
        if command_issued:
            cache_events = self.cache_events
        else:
            cache_events = None

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self.prepare_inputs(
            prompt_tokens, generation_tokens, context_lens, block_tables)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output


def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
    return x + [0] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int) -> List[int]:
    return x + [0] * (max_len - len(x))