Commit fb445dde authored by guanyu1
Browse files

误删函数恢复 (Restore a function that was deleted by mistake)

parent ef7e1214
......@@ -2073,11 +2073,11 @@ class GPUModelRunner(
)
return common_prefix_len if use_cascade else 0
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0
def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"):
xdrope_pos_ptr = 0
for index, req_id in enumerate(self.input_batch.req_ids):
req = self.requests[req_id]
assert req.mrope_positions is not None
assert req.xdrope_positions is not None
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
......@@ -2095,32 +2095,30 @@ class GPUModelRunner(
assert num_scheduled_tokens == prompt_part_len + completion_part_len
if prompt_part_len > 0:
# prompt's mrope_positions are pre-computed
dst_start = mrope_pos_ptr
dst_end = mrope_pos_ptr + prompt_part_len
# prompt's xdrope_positions are pre-computed
dst_start = xdrope_pos_ptr
dst_end = xdrope_pos_ptr + prompt_part_len
src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len
self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[
self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[
:, src_start:src_end
]
mrope_pos_ptr += prompt_part_len
xdrope_pos_ptr += prompt_part_len
if completion_part_len > 0:
# compute completion's mrope_positions on-the-fly
dst_start = mrope_pos_ptr
dst_end = mrope_pos_ptr + completion_part_len
# compute completion's xdrope_positions on-the-fly
dst_start = xdrope_pos_ptr
dst_end = xdrope_pos_ptr + completion_part_len
assert req.mrope_position_delta is not None
MRotaryEmbedding.get_next_input_positions_tensor(
out=self.mrope_positions.np,
XDRotaryEmbedding.get_next_input_positions_tensor(
out=self.xdrope_positions.np,
out_offset=dst_start,
mrope_position_delta=req.mrope_position_delta,
context_len=num_computed_tokens + prompt_part_len,
num_new_tokens=completion_part_len,
)
mrope_pos_ptr += completion_part_len
xdrope_pos_ptr += completion_part_len
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment