"tests/vscode:/vscode.git/clone" did not exist on "f26650d649aac25cb3b7a6b49863e1929da5df32"
Commit 333e3374 authored by lizhigong's avatar lizhigong
Browse files

debug and fix error

parent 3bbb6e9d
...@@ -414,7 +414,6 @@ class LLMEngine: ...@@ -414,7 +414,6 @@ class LLMEngine:
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1' self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
if self.zero_overhead: if self.zero_overhead:
assert os.environ.get('HIP_ALLOC_INITIALIZE') == '0'
self.async_d2h = None self.async_d2h = None
self.last_record = None self.last_record = None
self.async_event = torch.cuda.Event(enable_timing=False) self.async_event = torch.cuda.Event(enable_timing=False)
......
...@@ -79,6 +79,8 @@ class SampleRecorder: ...@@ -79,6 +79,8 @@ class SampleRecorder:
last_sampler = None last_sampler = None
zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
def get_last_sampler(): def get_last_sampler():
return last_sampler return last_sampler
...@@ -216,8 +218,6 @@ class Sampler(nn.Module): ...@@ -216,8 +218,6 @@ class Sampler(nn.Module):
# speculative decoding. # speculative decoding.
self.include_gpu_probs_tensor = False self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False self.should_modify_greedy_probs_inplace = False
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
d2d_data.zero_overhead = self.zero_overhead
def _init_sampling_tensors( def _init_sampling_tensors(
self, self,
...@@ -480,7 +480,7 @@ def _greedy_sample( ...@@ -480,7 +480,7 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], []) seq_group has do_sample=False, tuple contains ([], [])
""" """
if not d2d_data.zero_overhead: if not zero_overhead:
samples_lst = samples.tolist() samples_lst = samples.tolist()
sample_idx = 0 sample_idx = 0
results: SampleResultType = [] results: SampleResultType = []
...@@ -494,7 +494,7 @@ def _greedy_sample( ...@@ -494,7 +494,7 @@ def _greedy_sample(
assert num_parent_seqs == 1, ( assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.") "Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs)) parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead: if zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] #place holder token id next_token_ids = [0] #place holder token id
else: else:
...@@ -521,7 +521,7 @@ def _random_sample( ...@@ -521,7 +521,7 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], []) seq_group has do_sample=False, tuple contains ([], [])
""" """
# Find the maximum n value of the prompt phase requests. # Find the maximum n value of the prompt phase requests.
if not d2d_data.zero_overhead: if not zero_overhead:
random_samples = random_samples.cpu() random_samples = random_samples.cpu()
sample_idx = 0 sample_idx = 0
results: SampleResultType = [] results: SampleResultType = []
...@@ -537,7 +537,7 @@ def _random_sample( ...@@ -537,7 +537,7 @@ def _random_sample(
if is_prompt: if is_prompt:
# Prompt phase. # Prompt phase.
parent_ids = [0] * sampling_params.n parent_ids = [0] * sampling_params.n
if d2d_data.zero_overhead: if zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * sampling_params.n #place holder token id next_token_ids = [0] * sampling_params.n #place holder token id
else: else:
...@@ -546,7 +546,7 @@ def _random_sample( ...@@ -546,7 +546,7 @@ def _random_sample(
else: else:
# Generation phase. # Generation phase.
parent_ids = list(range(num_parent_seqs)) parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead: if zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * num_parent_seqs #place holder token id next_token_ids = [0] * num_parent_seqs #place holder token id
else: else:
...@@ -1310,9 +1310,7 @@ def _build_sampler_output( ...@@ -1310,9 +1310,7 @@ def _build_sampler_output(
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor, logprobs=logprobs_tensor,
deferred_sample_results_args=deferred_sample_results_args, deferred_sample_results_args=deferred_sample_results_args,
logits=logits, logits=logits)
sampler_out_tenosr = d2d_data.sampled_token_ids_tensor,
sampler_out_ids = d2d_data.seq_id)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
......
...@@ -909,6 +909,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -909,6 +909,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if self.zero_overhead: if self.zero_overhead:
last_sampler = get_last_sampler() last_sampler = get_last_sampler()
if last_sampler is not None:
update_indices = [] update_indices = []
select_indices = [] select_indices = []
for i, seq_id in enumerate(self.req_ids): for i, seq_id in enumerate(self.req_ids):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment