Commit d2c4f48b authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'gy_0151_mrope_1d-0312' into 'v0.15.1-dev'

mrope_1d修改

See merge request dcutoolkit/deeplearing/vllm!490
parents b22a4a14 6b03cfdb
...@@ -156,6 +156,7 @@ if TYPE_CHECKING: ...@@ -156,6 +156,7 @@ if TYPE_CHECKING:
VLLM_MXFP4_USE_MARLIN: bool | None = None VLLM_MXFP4_USE_MARLIN: bool | None = None
VLLM_DEEPEPLL_NVFP4_DISPATCH: bool = False VLLM_DEEPEPLL_NVFP4_DISPATCH: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_1D_MROPE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: int | None = None VLLM_TPU_MOST_MODEL_LEN: int | None = None
VLLM_TPU_USING_PATHWAYS: bool = False VLLM_TPU_USING_PATHWAYS: bool = False
...@@ -1889,6 +1890,8 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1889,6 +1890,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MOE_W16A16_TRITON": "VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")), ("true", "1")),
"VLLM_1D_MROPE":
lambda: (os.environ.get("VLLM_1D_MROPE", "0").lower() in ("true", "1")),
#If set to 1/True, enable the V1 fast token-id copy path in InputBatch. #If set to 1/True, enable the V1 fast token-id copy path in InputBatch.
"VLLM_V1_FAST_TOKEN_ID_COPY": "VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
......
...@@ -395,6 +395,7 @@ class GPUModelRunner( ...@@ -395,6 +395,7 @@ class GPUModelRunner(
# Multi-modal data support # Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope self.uses_mrope = model_config.uses_mrope
self.use_1d_mrope = self.uses_mrope and envs.VLLM_1D_MROPE
self.uses_xdrope_dim = model_config.uses_xdrope_dim self.uses_xdrope_dim = model_config.uses_xdrope_dim
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config model_config
...@@ -610,9 +611,15 @@ class GPUModelRunner( ...@@ -610,9 +611,15 @@ class GPUModelRunner(
# identical position IDs, making M-RoPE functionally equivalent to # identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE. # 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191 # See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = self._make_buffer( if self.use_1d_mrope:
(3, self.max_num_tokens + 1), dtype=torch.int64 self.mrope_positions = self._make_buffer(
) 3 * (self.max_num_tokens + 1), dtype=torch.int64
)
else:
self.mrope_positions = self._make_buffer(
(3, self.max_num_tokens + 1), dtype=torch.int64
)
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL) # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
...@@ -768,17 +775,24 @@ class GPUModelRunner( ...@@ -768,17 +775,24 @@ class GPUModelRunner(
def _get_positions(self, num_tokens: Any): def _get_positions(self, num_tokens: Any):
if isinstance(num_tokens, int): if isinstance(num_tokens, int):
if self.uses_mrope: if self.uses_mrope:
if self.use_1d_mrope:
return self.mrope_positions.gpu[: 3 * num_tokens].view(
num_tokens, 3
).T
return self.mrope_positions.gpu[:, :num_tokens] return self.mrope_positions.gpu[:, :num_tokens]
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, :num_tokens] return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens] return self.positions.gpu[:num_tokens]
else: else:
if self.uses_mrope: if self.uses_mrope:
if self.use_1d_mrope:
return self.mrope_positions.gpu.view(-1, 3)[num_tokens].T
return self.mrope_positions.gpu[:, num_tokens] return self.mrope_positions.gpu[:, num_tokens]
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, num_tokens] return self.xdrope_positions.gpu[:, num_tokens]
return self.positions.gpu[num_tokens] return self.positions.gpu[num_tokens]
def _make_buffer( def _make_buffer(
self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True
) -> CpuGpuBuffer: ) -> CpuGpuBuffer:
...@@ -789,6 +803,31 @@ class GPUModelRunner( ...@@ -789,6 +803,31 @@ class GPUModelRunner(
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
with_numpy=numpy, with_numpy=numpy,
) )
def _copy_mrope_positions_to_gpu(self, num_tokens: int) -> None:
if not self.uses_mrope:
return
if self.use_1d_mrope:
num_values = 3 * num_tokens
self.mrope_positions.gpu[:num_values].copy_(
self.mrope_positions.cpu[:num_values],
non_blocking=True,
)
return
self.mrope_positions.gpu[:, :num_tokens].copy_(
self.mrope_positions.cpu[:, :num_tokens],
non_blocking=True,
)
def _copy_xdrope_positions_to_gpu(self, num_tokens: int) -> None:
if self.uses_xdrope_dim <= 0:
return
self.xdrope_positions.gpu[:, :num_tokens].copy_(
self.xdrope_positions.cpu[:, :num_tokens],
non_blocking=True,
)
def _init_model_kwargs(self): def _init_model_kwargs(self):
model_kwargs = dict[str, Any]() model_kwargs = dict[str, Any]()
...@@ -1595,16 +1634,11 @@ class GPUModelRunner( ...@@ -1595,16 +1634,11 @@ class GPUModelRunner(
if self.uses_mrope: if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( self._copy_mrope_positions_to_gpu(total_num_scheduled_tokens)
self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True,
)
elif self.uses_xdrope_dim > 0: elif self.uses_xdrope_dim > 0:
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL) # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
self.xdrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( self._copy_xdrope_positions_to_gpu(total_num_scheduled_tokens)
self.xdrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True,
)
else: else:
# Common case (1D positions) # Common case (1D positions)
self.positions.copy_to_gpu(total_num_scheduled_tokens) self.positions.copy_to_gpu(total_num_scheduled_tokens)
...@@ -2045,11 +2079,11 @@ class GPUModelRunner( ...@@ -2045,11 +2079,11 @@ class GPUModelRunner(
) )
return common_prefix_len if use_cascade else 0 return common_prefix_len if use_cascade else 0
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0 xdrope_pos_ptr = 0
for index, req_id in enumerate(self.input_batch.req_ids): for index, req_id in enumerate(self.input_batch.req_ids):
req = self.requests[req_id] req = self.requests[req_id]
assert req.mrope_positions is not None assert req.xdrope_positions is not None
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index] num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
...@@ -2067,38 +2101,43 @@ class GPUModelRunner( ...@@ -2067,38 +2101,43 @@ class GPUModelRunner(
assert num_scheduled_tokens == prompt_part_len + completion_part_len assert num_scheduled_tokens == prompt_part_len + completion_part_len
if prompt_part_len > 0: if prompt_part_len > 0:
# prompt's mrope_positions are pre-computed # prompt's xdrope_positions are pre-computed
dst_start = mrope_pos_ptr dst_start = xdrope_pos_ptr
dst_end = mrope_pos_ptr + prompt_part_len dst_end = xdrope_pos_ptr + prompt_part_len
src_start = num_computed_tokens src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len src_end = num_computed_tokens + prompt_part_len
self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[ self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[
:, src_start:src_end :, src_start:src_end
] ]
mrope_pos_ptr += prompt_part_len xdrope_pos_ptr += prompt_part_len
if completion_part_len > 0: if completion_part_len > 0:
# compute completion's mrope_positions on-the-fly # compute completion's xdrope_positions on-the-fly
dst_start = mrope_pos_ptr dst_start = xdrope_pos_ptr
dst_end = mrope_pos_ptr + completion_part_len dst_end = xdrope_pos_ptr + completion_part_len
assert req.mrope_position_delta is not None XDRotaryEmbedding.get_next_input_positions_tensor(
MRotaryEmbedding.get_next_input_positions_tensor( out=self.xdrope_positions.np,
out=self.mrope_positions.np,
out_offset=dst_start, out_offset=dst_start,
mrope_position_delta=req.mrope_position_delta,
context_len=num_computed_tokens + prompt_part_len, context_len=num_computed_tokens + prompt_part_len,
num_new_tokens=completion_part_len, num_new_tokens=completion_part_len,
) )
mrope_pos_ptr += completion_part_len xdrope_pos_ptr += completion_part_len
def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"): def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
xdrope_pos_ptr = 0 mrope_pos_ptr = 0
if self.use_1d_mrope:
mrope_positions_token_major = self.mrope_positions.cpu.view(
self.max_num_tokens + 1, 3
)
mrope_positions_token_major_np = self.mrope_positions.np.reshape(
self.max_num_tokens + 1, 3
)
for index, req_id in enumerate(self.input_batch.req_ids): for index, req_id in enumerate(self.input_batch.req_ids):
req = self.requests[req_id] req = self.requests[req_id]
assert req.xdrope_positions is not None assert req.mrope_positions is not None
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index] num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
...@@ -2116,30 +2155,50 @@ class GPUModelRunner( ...@@ -2116,30 +2155,50 @@ class GPUModelRunner(
assert num_scheduled_tokens == prompt_part_len + completion_part_len assert num_scheduled_tokens == prompt_part_len + completion_part_len
if prompt_part_len > 0: if prompt_part_len > 0:
# prompt's xdrope_positions are pre-computed # prompt's mrope_positions are pre-computed
dst_start = xdrope_pos_ptr dst_start = mrope_pos_ptr
dst_end = xdrope_pos_ptr + prompt_part_len dst_end = mrope_pos_ptr + prompt_part_len
src_start = num_computed_tokens src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len src_end = num_computed_tokens + prompt_part_len
self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[ if self.use_1d_mrope:
:, src_start:src_end mrope_positions_token_major[dst_start:dst_end, :].copy_(
] req.mrope_positions[:, src_start:src_end].transpose(0, 1)
xdrope_pos_ptr += prompt_part_len )
else:
self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[
:, src_start:src_end
]
mrope_pos_ptr += prompt_part_len
if completion_part_len > 0: if completion_part_len > 0:
# compute completion's xdrope_positions on-the-fly # compute completion's mrope_positions on-the-fly
dst_start = xdrope_pos_ptr dst_start = mrope_pos_ptr
dst_end = xdrope_pos_ptr + completion_part_len dst_end = mrope_pos_ptr + completion_part_len
XDRotaryEmbedding.get_next_input_positions_tensor( assert req.mrope_position_delta is not None
out=self.xdrope_positions.np, if self.use_1d_mrope:
out_offset=dst_start, values = np.arange(
context_len=num_computed_tokens + prompt_part_len, req.mrope_position_delta + num_computed_tokens + prompt_part_len,
num_new_tokens=completion_part_len, req.mrope_position_delta
) + num_computed_tokens
+ prompt_part_len
+ completion_part_len,
dtype=mrope_positions_token_major_np.dtype,
)
mrope_positions_token_major_np[dst_start:dst_end, :] = values[
:, None
]
else:
MRotaryEmbedding.get_next_input_positions_tensor(
out=self.mrope_positions.np,
out_offset=dst_start,
mrope_position_delta=req.mrope_position_delta,
context_len=num_computed_tokens + prompt_part_len,
num_new_tokens=completion_part_len,
)
xdrope_pos_ptr += completion_part_len mrope_pos_ptr += completion_part_len
def _calc_spec_decode_metadata( def _calc_spec_decode_metadata(
self, self,
...@@ -2574,11 +2633,11 @@ class GPUModelRunner( ...@@ -2574,11 +2633,11 @@ class GPUModelRunner(
if should_sync_mrope_positions: if should_sync_mrope_positions:
self._calc_mrope_positions(scheduler_output) self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens) self._copy_mrope_positions_to_gpu(total_num_scheduled_tokens)
if should_sync_xdrope_positions: if should_sync_xdrope_positions:
self._calc_xdrope_positions(scheduler_output) self._calc_xdrope_positions(scheduler_output)
self.xdrope_positions.copy_to_gpu(total_num_scheduled_tokens) self._copy_xdrope_positions_to_gpu(total_num_scheduled_tokens)
return mm_embeds, is_mm_embed return mm_embeds, is_mm_embed
...@@ -2837,12 +2896,7 @@ class GPUModelRunner( ...@@ -2837,12 +2896,7 @@ class GPUModelRunner(
inputs_embeds = None inputs_embeds = None
model_kwargs = self._init_model_kwargs() model_kwargs = self._init_model_kwargs()
if self.uses_mrope: positions = self._get_positions(num_input_tokens)
positions = self.mrope_positions.gpu[:, :num_input_tokens]
elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_input_tokens]
else:
positions = self.positions.gpu[:num_input_tokens]
if is_first_rank: if is_first_rank:
intermediate_tensors = None intermediate_tensors = None
...@@ -4727,12 +4781,7 @@ class GPUModelRunner( ...@@ -4727,12 +4781,7 @@ class GPUModelRunner(
input_ids = self.input_ids.gpu[:num_tokens_padded] input_ids = self.input_ids.gpu[:num_tokens_padded]
inputs_embeds = None inputs_embeds = None
if self.uses_mrope: positions = self._get_positions(num_tokens_padded)
positions = self.mrope_positions.gpu[:, :num_tokens_padded]
elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_tokens_padded]
else:
positions = self.positions.gpu[:num_tokens_padded]
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
intermediate_tensors = None intermediate_tensors = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment