Commit fb445dde authored by guanyu1
Browse files

误删函数恢复

parent ef7e1214
...@@ -2073,11 +2073,11 @@ class GPUModelRunner( ...@@ -2073,11 +2073,11 @@ class GPUModelRunner(
) )
return common_prefix_len if use_cascade else 0 return common_prefix_len if use_cascade else 0
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0 xdrope_pos_ptr = 0
for index, req_id in enumerate(self.input_batch.req_ids): for index, req_id in enumerate(self.input_batch.req_ids):
req = self.requests[req_id] req = self.requests[req_id]
assert req.mrope_positions is not None assert req.xdrope_positions is not None
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index] num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
...@@ -2095,33 +2095,31 @@ class GPUModelRunner( ...@@ -2095,33 +2095,31 @@ class GPUModelRunner(
assert num_scheduled_tokens == prompt_part_len + completion_part_len assert num_scheduled_tokens == prompt_part_len + completion_part_len
if prompt_part_len > 0: if prompt_part_len > 0:
# prompt's mrope_positions are pre-computed # prompt's xdrope_positions are pre-computed
dst_start = mrope_pos_ptr dst_start = xdrope_pos_ptr
dst_end = mrope_pos_ptr + prompt_part_len dst_end = xdrope_pos_ptr + prompt_part_len
src_start = num_computed_tokens src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len src_end = num_computed_tokens + prompt_part_len
self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[ self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[
:, src_start:src_end :, src_start:src_end
] ]
mrope_pos_ptr += prompt_part_len xdrope_pos_ptr += prompt_part_len
if completion_part_len > 0: if completion_part_len > 0:
# compute completion's mrope_positions on-the-fly # compute completion's xdrope_positions on-the-fly
dst_start = mrope_pos_ptr dst_start = xdrope_pos_ptr
dst_end = mrope_pos_ptr + completion_part_len dst_end = xdrope_pos_ptr + completion_part_len
assert req.mrope_position_delta is not None XDRotaryEmbedding.get_next_input_positions_tensor(
MRotaryEmbedding.get_next_input_positions_tensor( out=self.xdrope_positions.np,
out=self.mrope_positions.np,
out_offset=dst_start, out_offset=dst_start,
mrope_position_delta=req.mrope_position_delta,
context_len=num_computed_tokens + prompt_part_len, context_len=num_computed_tokens + prompt_part_len,
num_new_tokens=completion_part_len, num_new_tokens=completion_part_len,
) )
mrope_pos_ptr += completion_part_len xdrope_pos_ptr += completion_part_len
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0 mrope_pos_ptr = 0
if self.use_1d_mrope: if self.use_1d_mrope:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment