Commit b01c8270 authored by lizhigong's avatar lizhigong
Browse files

Delete the Triton kernel; use tensor indices instead

parent 1ed30424
...@@ -263,7 +263,6 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -263,7 +263,6 @@ class ZeroOverheadEngine(LLMEngine):
self._skip_scheduling_next_step = False self._skip_scheduling_next_step = False
self.async_d2h = None self.async_d2h = None
self.last_record = None self.last_record = None
assert os.environ.get('HIP_ALLOC_INITIALIZE') == '0'
self.async_event = torch.cuda.Event(enable_timing=False) self.async_event = torch.cuda.Event(enable_timing=False)
self.thread_running = False self.thread_running = False
self.q_recorder = queue.Queue() self.q_recorder = queue.Queue()
...@@ -410,7 +409,7 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -410,7 +409,7 @@ class ZeroOverheadEngine(LLMEngine):
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist() #sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
sample_out_list = self.async_d2h.tolist() sample_out_list = self.async_d2h.tolist()
sample_out_ids = last_sampler.seq_id.tolist() sample_out_ids = last_sampler.seq_ids
for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
zip(seq_group_metadata_list, output[0], scheduled_seq_groups): zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
......
...@@ -36,11 +36,20 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder): ...@@ -36,11 +36,20 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
model_input = super().build() model_input = super().build()
last_sampler = get_last_sampler() last_sampler = get_last_sampler()
if last_sampler is not None: if last_sampler is not None:
input_ids = async_tensor_h2d(self.req_ids, torch.long, update_indices = []
select_indices = []
for i, seq_id in enumerate(self.req_ids):
for j, seq_id_ in enumerate(last_sampler.seq_ids):
if seq_id == seq_id_:
select_indices.append(j)
update_indices.append(i)
break
select_indices = async_tensor_h2d(select_indices, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
last_ids = async_tensor_h2d(last_sampler.seq_id.tolist(), torch.long, update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
UpdateInputTokens(model_input.input_tokens, input_ids, last_sampler.sampled_token_ids_tensor, last_ids) if len(select_indices) > 0:
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[select_indices, 0]
return model_input return model_input
...@@ -22,7 +22,7 @@ else: ...@@ -22,7 +22,7 @@ else:
class SampleRecorder: class SampleRecorder:
def __init__(self): def __init__(self):
self.seq_id:torch.Tensor = None self.seq_ids:torch.Tensor = None
self.sampled_token_ids_tensor:torch.Tensor = None self.sampled_token_ids_tensor:torch.Tensor = None
last_sampler = None last_sampler = None
...@@ -275,10 +275,10 @@ def _sample_with_torch( ...@@ -275,10 +275,10 @@ def _sample_with_torch(
t: [] t: []
for t in SamplingType for t in SamplingType
} }
last_sampler.seq_id = torch.zeros(len(sampling_metadata.seq_groups), dtype=torch.int32) last_sampler.seq_ids = []
categorized_sample_indices = sampling_metadata.categorized_sample_indices categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
last_sampler.seq_id[i] = seq_group.seq_ids[0] last_sampler.seq_ids.append(seq_group.seq_ids[0])
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
......
import torch
import triton
import triton.language as tl
# BATCH_SIZE1 / BATCH_SIZE2 are excluded from Triton's integer-value
# specialization. The launcher in this file keeps the compiled-kernel handle
# returned by the first launch and reuses it for later launches, so the
# compiled code must not bake in properties of the first call's batch sizes
# (for example a value of exactly 1 being folded to a constant).
@triton.jit(do_not_specialize=["BATCH_SIZE1", "BATCH_SIZE2"])
def _update_input_tokens(
    sample_output,
    seq_ids,
    input_tokens,
    input_seq_ids,
    BATCH_SIZE1,
    BATCH_SIZE2,
):
    # Overwrite input_tokens[pid] with the sampled token whose sequence id
    # matches input_seq_ids[pid]; leave the slot unchanged when none matches.
    #
    # sample_output: pointer read at offsets 0..BATCH_SIZE1-1 (sampled tokens).
    # seq_ids:       pointer read at offsets 0..BATCH_SIZE1-1 (ids compared
    #                against the id of the current input slot).
    # input_tokens:  pointer read and written at offset pid.
    # input_seq_ids: pointer read at offset pid.
    # BATCH_SIZE1:   number of (seq_ids, sample_output) entries to scan.
    # BATCH_SIZE2:   number of input slots; programs with pid beyond it exit.
    pid = tl.program_id(0)
    if pid >= BATCH_SIZE2:
        return

    # Start from the current token so that a slot with no match is stored
    # back unchanged.
    output_token = tl.load(input_tokens + pid)
    _input_seq_id = tl.load(input_seq_ids + pid)

    # Linear scan over all previous ids. There is no early exit, so when the
    # same id occurs more than once the last occurrence wins.
    for i in range(BATCH_SIZE1):
        _seq_ids = tl.load(seq_ids + i)
        if _seq_ids == _input_seq_id:
            output_token = tl.load(sample_output + i)

    tl.store(input_tokens + pid, output_token)
# Handle returned by the first `_update_input_tokens[grid](...)` launch; later
# calls launch through this handle instead of going through the JIT function.
_update_input_tokens_ptr = None


def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
    """Launch `_update_input_tokens` with one program per input sequence id.

    The grid has `input_seq_ids.shape[0]` programs along its first axis. The
    kernel receives `last_sample`, `last_ids`, `input_tokens` and
    `input_seq_ids`, followed by the first-dimension sizes of `last_ids` and
    `input_seq_ids`. Nothing is returned.
    """
    global _update_input_tokens_ptr

    num_inputs = input_seq_ids.shape[0]
    launch_grid = [num_inputs, 1, 1]
    kernel_args = (
        last_sample,
        last_ids,
        input_tokens,
        input_seq_ids,
        last_ids.shape[0],
        num_inputs,
    )

    if _update_input_tokens_ptr is not None:
        # Reuse the handle stored by an earlier call.
        _update_input_tokens_ptr[launch_grid](*kernel_args)
        return

    # First call: launch through the JIT function and remember what it returns.
    _update_input_tokens_ptr = _update_input_tokens[launch_grid](*kernel_args)
\ No newline at end of file
...@@ -9,4 +9,12 @@ def is_zero_overhead(): ...@@ -9,4 +9,12 @@ def is_zero_overhead():
return zero_overhead return zero_overhead
def is_zero_no_thread(): def is_zero_no_thread():
return zero_no_thread and zero_overhead return zero_no_thread and zero_overhead
\ No newline at end of file
def UpdateInputTokens(input_tokens, last_sample, indices):
    """Launch `_update_input_tokens` with one program per entry of `input_tokens`.

    The grid has `input_tokens.shape[0]` programs along its first axis. The
    kernel receives `last_sample`, `input_tokens`, `indices` and that same
    first-dimension size. The handle returned by the first launch is stored in
    the module-level `_update_input_tokens_ptr` and used for every later
    launch. Nothing is returned.
    """
    global _update_input_tokens_ptr

    num_tokens = input_tokens.shape[0]
    launch_grid = [num_tokens, 1, 1]
    kernel_args = (last_sample, input_tokens, indices, num_tokens)

    if _update_input_tokens_ptr is not None:
        # Reuse the handle stored by an earlier call.
        _update_input_tokens_ptr[launch_grid](*kernel_args)
        return

    # First call: launch through the JIT function and remember what it returns.
    _update_input_tokens_ptr = _update_input_tokens[launch_grid](*kernel_args)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment