Unverified Commit 9a3b8832 authored by Vadim Gimpelson's avatar Vadim Gimpelson Committed by GitHub
Browse files

[PERF] Speedup of MRoPE prepare inputs (#19939)


Signed-off-by: default avatarVadim Gimpelson <vadim.gimpelson@centml.ai>
parent 3014c920
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
import math import math
from typing import Any, Optional, Union from typing import Any, Optional, Union
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -1458,15 +1459,14 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1458,15 +1459,14 @@ class MRotaryEmbedding(RotaryEmbedding):
] ]
@staticmethod @staticmethod
def get_next_input_positions_tensor( def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
mrope_position_delta: int, mrope_position_delta: int,
context_len: int, context_len: int, num_new_tokens: int):
seq_len: int,
) -> torch.Tensor: values = np.arange(mrope_position_delta + context_len,
return torch.arange( mrope_position_delta + context_len + num_new_tokens,
mrope_position_delta + context_len, dtype=out.dtype)
mrope_position_delta + seq_len, out[:, out_offset:out_offset + num_new_tokens] = values
).expand(3, -1)
@classmethod @classmethod
def omni_get_updates_use_audio_in_video( def omni_get_updates_use_audio_in_video(
......
...@@ -262,6 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -262,6 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int64, dtype=torch.int64,
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.mrope_positions_np = self.mrope_positions_cpu.numpy()
# Only relevant for models using ALiBi (e.g, MPT) # Only relevant for models using ALiBi (e.g, MPT)
self.use_alibi = check_use_alibi(model_config) self.use_alibi = check_use_alibi(model_config)
...@@ -889,15 +890,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -889,15 +890,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dst_start = mrope_pos_ptr dst_start = mrope_pos_ptr
dst_end = mrope_pos_ptr + completion_part_len dst_end = mrope_pos_ptr + completion_part_len
self.mrope_positions_cpu[:, dst_start:dst_end] = \ MRotaryEmbedding.get_next_input_positions_tensor(
MRotaryEmbedding.get_next_input_positions_tensor( out=self.mrope_positions_np,
req.mrope_position_delta, out_offset=dst_start,
context_len=num_computed_tokens + mrope_position_delta=req.mrope_position_delta,
prompt_part_len, context_len=num_computed_tokens + prompt_part_len,
seq_len=num_computed_tokens + num_new_tokens=completion_part_len,
prompt_part_len + )
completion_part_len,
)
mrope_pos_ptr += completion_part_len mrope_pos_ptr += completion_part_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment