Commit 2799cc7a authored by guanyu1's avatar guanyu1
Browse files

删除1d_mrope

parent 34d497a1
......@@ -157,7 +157,6 @@ if TYPE_CHECKING:
VLLM_MXFP4_USE_MARLIN: bool | None = None
VLLM_DEEPEPLL_NVFP4_DISPATCH: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_1D_MROPE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: int | None = None
VLLM_TPU_USING_PATHWAYS: bool = False
......@@ -1926,8 +1925,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")),
"VLLM_1D_MROPE":
lambda: (os.environ.get("VLLM_1D_MROPE", "0").lower() in ("true", "1")),
#If set to 1/True, enable the V1 fast token-id copy path in InputBatch.
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
......
......@@ -398,7 +398,6 @@ class GPUModelRunner(
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
self.use_1d_mrope = self.uses_mrope and envs.VLLM_1D_MROPE
self.uses_xdrope_dim = model_config.uses_xdrope_dim
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config
......@@ -614,15 +613,9 @@ class GPUModelRunner(
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
if self.use_1d_mrope:
self.mrope_positions = self._make_buffer(
3 * (self.max_num_tokens + 1), dtype=torch.int64
)
else:
self.mrope_positions = self._make_buffer(
(3, self.max_num_tokens + 1), dtype=torch.int64
)
self.mrope_positions = self._make_buffer(
(3, self.max_num_tokens + 1), dtype=torch.int64
)
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
if self.uses_xdrope_dim > 0:
......@@ -778,18 +771,12 @@ class GPUModelRunner(
def _get_positions(self, num_tokens: Any):
if isinstance(num_tokens, int):
if self.uses_mrope:
if self.use_1d_mrope:
return self.mrope_positions.gpu[: 3 * num_tokens].view(
num_tokens, 3
).T
return self.mrope_positions.gpu[:, :num_tokens]
if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens]
else:
if self.uses_mrope:
if self.use_1d_mrope:
return self.mrope_positions.gpu.view(-1, 3)[num_tokens].T
return self.mrope_positions.gpu[:, num_tokens]
if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, num_tokens]
......@@ -806,17 +793,10 @@ class GPUModelRunner(
pin_memory=self.pin_memory,
with_numpy=numpy,
)
def _copy_mrope_positions_to_gpu(self, num_tokens: int) -> None:
if not self.uses_mrope:
return
if self.use_1d_mrope:
num_values = 3 * num_tokens
self.mrope_positions.gpu[:num_values].copy_(
self.mrope_positions.cpu[:num_values],
non_blocking=True,
)
return
self.mrope_positions.gpu[:, :num_tokens].copy_(
self.mrope_positions.cpu[:, :num_tokens],
non_blocking=True,
......@@ -825,7 +805,7 @@ class GPUModelRunner(
def _copy_xdrope_positions_to_gpu(self, num_tokens: int) -> None:
if self.uses_xdrope_dim <= 0:
return
self.xdrope_positions.gpu[:, :num_tokens].copy_(
self.xdrope_positions.cpu[:, :num_tokens],
non_blocking=True,
......@@ -2131,13 +2111,6 @@ class GPUModelRunner(
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0
if self.use_1d_mrope:
mrope_positions_token_major = self.mrope_positions.cpu.view(
self.max_num_tokens + 1, 3
)
mrope_positions_token_major_np = self.mrope_positions.np.reshape(
self.max_num_tokens + 1, 3
)
for index, req_id in enumerate(self.input_batch.req_ids):
req = self.requests[req_id]
assert req.mrope_positions is not None
......@@ -2164,14 +2137,9 @@ class GPUModelRunner(
src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len
if self.use_1d_mrope:
mrope_positions_token_major[dst_start:dst_end, :].copy_(
req.mrope_positions[:, src_start:src_end].transpose(0, 1)
)
else:
self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[
:, src_start:src_end
]
self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[
:, src_start:src_end
]
mrope_pos_ptr += prompt_part_len
if completion_part_len > 0:
......@@ -2180,26 +2148,13 @@ class GPUModelRunner(
dst_end = mrope_pos_ptr + completion_part_len
assert req.mrope_position_delta is not None
if self.use_1d_mrope:
values = np.arange(
req.mrope_position_delta + num_computed_tokens + prompt_part_len,
req.mrope_position_delta
+ num_computed_tokens
+ prompt_part_len
+ completion_part_len,
dtype=mrope_positions_token_major_np.dtype,
)
mrope_positions_token_major_np[dst_start:dst_end, :] = values[
:, None
]
else:
MRotaryEmbedding.get_next_input_positions_tensor(
out=self.mrope_positions.np,
out_offset=dst_start,
mrope_position_delta=req.mrope_position_delta,
context_len=num_computed_tokens + prompt_part_len,
num_new_tokens=completion_part_len,
)
MRotaryEmbedding.get_next_input_positions_tensor(
out=self.mrope_positions.np,
out_offset=dst_start,
mrope_position_delta=req.mrope_position_delta,
context_len=num_computed_tokens + prompt_part_len,
num_new_tokens=completion_part_len,
)
mrope_pos_ptr += completion_part_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment