Commit d359cda5 authored by Woosuk Kwon
Browse files

Minor

parent 2f49f155
...@@ -7,7 +7,7 @@ from cacheflow.sequence import SequenceStatus ...@@ -7,7 +7,7 @@ from cacheflow.sequence import SequenceStatus
from cacheflow.utils import Device from cacheflow.utils import Device
class BlockManager: class BlockAllocator:
def __init__( def __init__(
self, self,
...@@ -65,8 +65,8 @@ class BlockSpaceManager: ...@@ -65,8 +65,8 @@ class BlockSpaceManager:
self.num_total_gpu_blocks = num_gpu_blocks self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks self.num_total_cpu_blocks = num_cpu_blocks
self.gpu_allocator = BlockManager(Device.GPU, block_size, num_gpu_blocks) self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator = BlockManager(Device.CPU, block_size, num_cpu_blocks) self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
# Mapping: seq_id -> BlockTable. # Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {} self.block_tables: Dict[int, BlockTable] = {}
......
...@@ -8,6 +8,7 @@ from cacheflow.sampling_params import SamplingParams ...@@ -8,6 +8,7 @@ from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceOutputs from cacheflow.sequence import SequenceOutputs
from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_parallel_region from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_parallel_region
class Sampler(nn.Module): class Sampler(nn.Module):
def __init__(self) -> None: def __init__(self) -> None:
......
...@@ -30,7 +30,7 @@ class Sequence: ...@@ -30,7 +30,7 @@ class Sequence:
self.status = SequenceStatus.PENDING self.status = SequenceStatus.PENDING
self.output_logprobs: List[Dict[int, float]] = [] self.output_logprobs: List[Dict[int, float]] = []
self.cumulative_logprobs = 1.0 self.cumulative_logprobs = 0.0
def add_block(self) -> None: def add_block(self) -> None:
block = LogicalTokenBlock( block = LogicalTokenBlock(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment