"src/array/cuda/spmm.hip" did not exist on "272cb9e29aaa2bb3ee6eb31003530a537c0bee3d"
Unverified commit e12358dc, authored by Lianmin Zheng and committed by GitHub

Simplify the usage of device (#1734)

parent 554fbf93
@@ -425,7 +425,6 @@ class ScheduleBatch:
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
     output_ids: torch.Tensor = None

     # For processing logprobs
@@ -442,27 +441,23 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False

-    # device
-    device: str = "cuda"
-
     # Has regex
     has_regex: bool = False

+    # device
+    device: str = "cuda"
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_logprob = any(req.return_logprob for req in reqs)
-        has_stream = any(req.stream for req in reqs)
-        has_regex = any(req.regex_fsm for req in reqs)
-
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
-            return_logprob=return_logprob,
-            has_stream=has_stream,
+            return_logprob=any(req.return_logprob for req in reqs),
+            has_stream=any(req.stream for req in reqs),
+            has_regex=any(req.regex_fsm for req in reqs),
             device=req_to_token_pool.device,
-            has_regex=has_regex,
         )

     def batch_size(self):
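
The `init_new` change above inlines the three `any(...)` flag computations and keeps `device` sourced from the token pool rather than the hardcoded `"cuda"` default. A minimal self-contained sketch of the same pattern; `Req`, `Pool`, and `Batch` here are simplified stand-ins, not the real SGLang types:

from dataclasses import dataclass
from typing import Any, List

@dataclass
class Req:
    # Simplified stand-in for a request object.
    return_logprob: bool = False
    stream: bool = False
    regex_fsm: Any = None

@dataclass
class Pool:
    # Simplified stand-in for req_to_token_pool.
    device: str = "cuda:0"

@dataclass
class Batch:
    reqs: List[Req]
    return_logprob: bool = False
    has_stream: bool = False
    has_regex: bool = False
    device: str = "cuda"

    @classmethod
    def init_new(cls, reqs, pool):
        # Compute batch-level flags inline and take the device from the
        # pool, so tensors allocated later land on the right device.
        return cls(
            reqs=reqs,
            return_logprob=any(r.return_logprob for r in reqs),
            has_stream=any(r.stream for r in reqs),
            has_regex=any(r.regex_fsm for r in reqs),
            device=pool.device,
        )

batch = Batch.init_new([Req(stream=True)], Pool(device="cuda:1"))
assert batch.has_stream and batch.device == "cuda:1"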
@@ -754,7 +749,7 @@ class ScheduleBatch:
         return jump_forward_reqs

-    def prepare_for_decode(self):
+    def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE
         self.input_ids = self.output_ids

@@ -767,10 +762,19 @@ class ScheduleBatch:
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

-        self.req_to_token_pool.write(
-            (self.req_pool_indices, self.seq_lens), self.out_cache_loc
-        )
-        self.seq_lens.add_(1)
+        if enable_overlap:
+            # Do not use in-place operations in the overlap mode
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+            )
+            self.seq_lens = self.seq_lens + 1
+        else:
+            # A faster in-place version
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, self.seq_lens), self.out_cache_loc
+            )
+            self.seq_lens.add_(1)
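
The two branches in `prepare_for_decode` write the same token slots; they differ only in how `seq_lens` is advanced. `add_(1)` mutates the tensor in place, so anything still holding a reference to it, such as a worker-batch snapshot being consumed by the overlapped worker thread, would observe the bumped values. `self.seq_lens + 1` allocates a fresh tensor and rebinds the attribute, leaving the old snapshot untouched at the cost of one extra allocation. A small illustrative PyTorch demonstration of the aliasing difference (not SGLang code):

import torch

# In-place: a snapshot that aliases the tensor silently changes too.
seq_lens = torch.tensor([5, 7, 9])
snapshot = seq_lens                 # e.g. shared with a worker batch
seq_lens.add_(1)
assert torch.equal(snapshot, torch.tensor([6, 8, 10]))

# Out-of-place: rebinding leaves the snapshot intact.
seq_lens = torch.tensor([5, 7, 9])
snapshot = seq_lens
seq_lens = seq_lens + 1             # new tensor; snapshot keeps old storage
assert torch.equal(snapshot, torch.tensor([5, 7, 9]))
assert torch.equal(seq_lens, torch.tensor([6, 8, 10]))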
     def filter_batch(
         self,

@@ -882,6 +886,7 @@ class ScheduleBatch:
         )

     def copy(self):
+        # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
             forward_mode=self.forward_mode,
@@ -940,9 +945,9 @@ class ModelWorkerBatch:
         return ModelWorkerBatch(
             bid=self.bid,
             forward_mode=self.forward_mode,
-            input_ids=self.input_ids.clone(),
+            input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens.clone(),
+            seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             req_to_token_pool_records=self.req_to_token_pool_records,
             return_logprob=self.return_logprob,
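
Dropping the `.clone()` calls makes this copy zero-copy for `input_ids` and `seq_lens`. That looks safe precisely because the overlap path in `prepare_for_decode` now rebinds `seq_lens` rather than mutating it, so the scheduler never writes into a tensor a copied batch still holds. A hedged sketch of the trade-off, with a simplified field set rather than the full `ModelWorkerBatch`:

from dataclasses import dataclass
import torch

@dataclass
class WorkerBatch:
    # Simplified stand-in for ModelWorkerBatch.
    input_ids: torch.Tensor
    seq_lens: torch.Tensor

    def copy(self):
        # Share the tensors instead of cloning them. Correct as long as
        # the producer only rebinds these fields (x = x + 1) and never
        # mutates them in place (x.add_(1)).
        return WorkerBatch(input_ids=self.input_ids, seq_lens=self.seq_lens)

wb = WorkerBatch(torch.tensor([1, 2]), torch.tensor([3, 4]))
wb2 = wb.copy()
assert wb2.seq_lens.data_ptr() == wb.seq_lens.data_ptr()  # shared storage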
@@ -103,6 +103,7 @@ class Scheduler:
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
+        self.enable_overlap = server_args.enable_overlap_schedule

         # Init inter-process communication
         context = zmq.Context(2)
@@ -146,7 +147,7 @@ class Scheduler:
         )

         # Launch a tensor parallel worker
-        if self.server_args.enable_overlap_schedule:
+        if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
@@ -670,7 +671,7 @@ class Scheduler:
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
-            self.running_batch.prepare_for_decode()
+            self.running_batch.prepare_for_decode(self.enable_overlap)
             new_batch.mix_with_running(self.running_batch)
             new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
@@ -717,7 +718,7 @@ class Scheduler:
             return

         # Update batch tensors
-        batch.prepare_for_decode()
+        batch.prepare_for_decode(self.enable_overlap)

     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
@@ -51,7 +51,7 @@ class SamplingBatchInfo:
         disable_penalizer: bool,
     ):
         reqs = batch.reqs
-        device = batch.input_ids.device
+        device = batch.device
         temperatures = (
             torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
@@ -95,7 +95,7 @@ class SamplingBatchInfo:
         ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-            device=batch.input_ids.device,
+            device=batch.device,
             Penalizers={
                 penaltylib.BatchedFrequencyPenalizer,
                 penaltylib.BatchedMinNewTokensPenalizer,
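
Reading `batch.device` instead of `batch.input_ids.device` ties sampling-info construction to the `device` field populated in `ScheduleBatch.init_new`, rather than to whether `input_ids` happens to be materialized at that moment (presumably relevant in the overlap schedule, where token ids may still be pending). A minimal sketch of building a sampling tensor on the batch's device; the stand-in classes are illustrative, not the real `SamplingBatchInfo` internals:

from dataclasses import dataclass
import torch

@dataclass
class SamplingParams:
    # Stand-in for the real sampling params object.
    temperature: float = 1.0

@dataclass
class Req:
    sampling_params: SamplingParams

def build_temperatures(reqs, device):
    # Build the temperature column vector directly on the batch's device,
    # mirroring the `device=batch.device` usage above.
    return torch.tensor(
        [r.sampling_params.temperature for r in reqs],
        dtype=torch.float32,
        device=device,
    ).view(-1, 1)

temps = build_temperatures([Req(SamplingParams(0.7))], device="cpu")
assert temps.shape == (1, 1)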