Commit c07d9253 authored by guanyu1's avatar guanyu1
Browse files

补充1d_mrope

parent 2d940766
...@@ -157,6 +157,7 @@ if TYPE_CHECKING: ...@@ -157,6 +157,7 @@ if TYPE_CHECKING:
VLLM_MXFP4_USE_MARLIN: bool | None = None VLLM_MXFP4_USE_MARLIN: bool | None = None
VLLM_DEEPEPLL_NVFP4_DISPATCH: bool = False VLLM_DEEPEPLL_NVFP4_DISPATCH: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_1D_MROPE: bool = False
VLLM_ENCODER_CACHE_SIZE: int | None = None VLLM_ENCODER_CACHE_SIZE: int | None = None
VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: int | None = None VLLM_TPU_MOST_MODEL_LEN: int | None = None
...@@ -1926,6 +1927,8 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1926,6 +1927,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MOE_W16A16_TRITON": "VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")), ("true", "1")),
"VLLM_1D_MROPE":
lambda: (os.environ.get("VLLM_1D_MROPE", "0").lower() in ("true", "1")),
"VLLM_ENCODER_CACHE_SIZE": "VLLM_ENCODER_CACHE_SIZE":
lambda: maybe_convert_int(os.environ.get("VLLM_ENCODER_CACHE_SIZE", None)), lambda: maybe_convert_int(os.environ.get("VLLM_ENCODER_CACHE_SIZE", None)),
#If set to 1/True, enable the V1 fast token-id copy path in InputBatch. #If set to 1/True, enable the V1 fast token-id copy path in InputBatch.
......
...@@ -398,6 +398,7 @@ class GPUModelRunner( ...@@ -398,6 +398,7 @@ class GPUModelRunner(
# Multi-modal data support # Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope self.uses_mrope = model_config.uses_mrope
self.use_1d_mrope = self.uses_mrope and envs.VLLM_1D_MROPE
self.uses_xdrope_dim = model_config.uses_xdrope_dim self.uses_xdrope_dim = model_config.uses_xdrope_dim
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config model_config
...@@ -613,9 +614,14 @@ class GPUModelRunner( ...@@ -613,9 +614,14 @@ class GPUModelRunner(
# identical position IDs, making M-RoPE functionally equivalent to # identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE. # 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191 # See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = self._make_buffer( if self.use_1d_mrope:
(3, self.max_num_tokens + 1), dtype=torch.int64 self.mrope_positions = self._make_buffer(
) 3 * (self.max_num_tokens + 1), dtype=torch.int64
)
else:
self.mrope_positions = self._make_buffer(
(3, self.max_num_tokens + 1), dtype=torch.int64
)
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL) # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
...@@ -771,12 +777,18 @@ class GPUModelRunner( ...@@ -771,12 +777,18 @@ class GPUModelRunner(
def _get_positions(self, num_tokens: Any): def _get_positions(self, num_tokens: Any):
if isinstance(num_tokens, int): if isinstance(num_tokens, int):
if self.uses_mrope: if self.uses_mrope:
if self.use_1d_mrope:
return self.mrope_positions.gpu[: 3 * num_tokens].view(
num_tokens, 3
).T
return self.mrope_positions.gpu[:, :num_tokens] return self.mrope_positions.gpu[:, :num_tokens]
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, :num_tokens] return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens] return self.positions.gpu[:num_tokens]
else: else:
if self.uses_mrope: if self.uses_mrope:
if self.use_1d_mrope:
return self.mrope_positions.gpu.view(-1, 3)[num_tokens].T
return self.mrope_positions.gpu[:, num_tokens] return self.mrope_positions.gpu[:, num_tokens]
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, num_tokens] return self.xdrope_positions.gpu[:, num_tokens]
...@@ -797,6 +809,13 @@ class GPUModelRunner( ...@@ -797,6 +809,13 @@ class GPUModelRunner(
def _copy_mrope_positions_to_gpu(self, num_tokens: int) -> None: def _copy_mrope_positions_to_gpu(self, num_tokens: int) -> None:
if not self.uses_mrope: if not self.uses_mrope:
return return
if self.use_1d_mrope:
num_values = 3 * num_tokens
self.mrope_positions.gpu[:num_values].copy_(
self.mrope_positions.cpu[:num_values],
non_blocking=True,
)
return
self.mrope_positions.gpu[:, :num_tokens].copy_( self.mrope_positions.gpu[:, :num_tokens].copy_(
self.mrope_positions.cpu[:, :num_tokens], self.mrope_positions.cpu[:, :num_tokens],
non_blocking=True, non_blocking=True,
...@@ -2111,6 +2130,13 @@ class GPUModelRunner( ...@@ -2111,6 +2130,13 @@ class GPUModelRunner(
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0 mrope_pos_ptr = 0
if self.use_1d_mrope:
mrope_positions_token_major = self.mrope_positions.cpu.view(
self.max_num_tokens + 1, 3
)
mrope_positions_token_major_np = self.mrope_positions.np.reshape(
self.max_num_tokens + 1, 3
)
for index, req_id in enumerate(self.input_batch.req_ids): for index, req_id in enumerate(self.input_batch.req_ids):
req = self.requests[req_id] req = self.requests[req_id]
assert req.mrope_positions is not None assert req.mrope_positions is not None
...@@ -2137,9 +2163,14 @@ class GPUModelRunner( ...@@ -2137,9 +2163,14 @@ class GPUModelRunner(
src_start = num_computed_tokens src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len src_end = num_computed_tokens + prompt_part_len
self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[ if self.use_1d_mrope:
:, src_start:src_end mrope_positions_token_major[dst_start:dst_end, :].copy_(
] req.mrope_positions[:, src_start:src_end].transpose(0, 1)
)
else:
self.mrope_positions.cpu[:, dst_start:dst_end] = (
req.mrope_positions[:, src_start:src_end]
)
mrope_pos_ptr += prompt_part_len mrope_pos_ptr += prompt_part_len
if completion_part_len > 0: if completion_part_len > 0:
...@@ -2148,13 +2179,28 @@ class GPUModelRunner( ...@@ -2148,13 +2179,28 @@ class GPUModelRunner(
dst_end = mrope_pos_ptr + completion_part_len dst_end = mrope_pos_ptr + completion_part_len
assert req.mrope_position_delta is not None assert req.mrope_position_delta is not None
MRotaryEmbedding.get_next_input_positions_tensor( if self.use_1d_mrope:
out=self.mrope_positions.np, values = np.arange(
out_offset=dst_start, req.mrope_position_delta
mrope_position_delta=req.mrope_position_delta, + num_computed_tokens
context_len=num_computed_tokens + prompt_part_len, + prompt_part_len,
num_new_tokens=completion_part_len, req.mrope_position_delta
) + num_computed_tokens
+ prompt_part_len
+ completion_part_len,
dtype=mrope_positions_token_major_np.dtype,
)
mrope_positions_token_major_np[dst_start:dst_end, :] = values[
:, None
]
else:
MRotaryEmbedding.get_next_input_positions_tensor(
out=self.mrope_positions.np,
out_offset=dst_start,
mrope_position_delta=req.mrope_position_delta,
context_len=num_computed_tokens + prompt_part_len,
num_new_tokens=completion_part_len,
)
mrope_pos_ptr += completion_part_len mrope_pos_ptr += completion_part_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment