Commit 63132045 authored by guanyu1's avatar guanyu1
Browse files

mrope的_get_position修改

parent e685ec92
...@@ -773,12 +773,18 @@ class GPUModelRunner( ...@@ -773,12 +773,18 @@ class GPUModelRunner(
def _get_positions(self, num_tokens: Any): def _get_positions(self, num_tokens: Any):
if isinstance(num_tokens, int): if isinstance(num_tokens, int):
if self.uses_mrope: if self.uses_mrope:
if self.use_1d_mrope:
return self.mrope_positions.gpu[: 3 * num_tokens].view(
num_tokens, 3
).T
return self.mrope_positions.gpu[:, :num_tokens] return self.mrope_positions.gpu[:, :num_tokens]
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, :num_tokens] return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens] return self.positions.gpu[:num_tokens]
else: else:
if self.uses_mrope: if self.uses_mrope:
if self.use_1d_mrope:
return self.mrope_positions.gpu.view(-1, 3)[num_tokens].T
return self.mrope_positions.gpu[:, num_tokens] return self.mrope_positions.gpu[:, num_tokens]
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, num_tokens] return self.xdrope_positions.gpu[:, num_tokens]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment